feat(config): simplify config migrator logic

- Remove `Migrations` class - unnecessary complexity on top of `MigrationEntry`
- Move common classes to `config_common`
- Tidy docstrings, variable names
This commit is contained in:
psychedelicious 2024-05-14 16:20:40 +10:00
parent fc23b16a73
commit 6946a3871f
3 changed files with 137 additions and 141 deletions

View File

@ -12,6 +12,10 @@ from __future__ import annotations
import argparse import argparse
import pydoc import pydoc
from dataclasses import dataclass
from typing import Any, Callable, TypeAlias
from packaging.version import Version
class PagingArgumentParser(argparse.ArgumentParser): class PagingArgumentParser(argparse.ArgumentParser):
@ -23,3 +27,21 @@ class PagingArgumentParser(argparse.ArgumentParser):
def print_help(self, file=None) -> None: def print_help(self, file=None) -> None:
text = self.format_help() text = self.format_help()
pydoc.pager(text) pydoc.pager(text)
AppConfigDict: TypeAlias = dict[str, Any]
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
@dataclass
class MigrationEntry:
"""Defines an individual migration."""
from_version: Version
to_version: Version
function: MigrationFunction
def __hash__(self) -> int:
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
return hash((self.from_version, self.to_version))

View File

@ -6,60 +6,38 @@ Utility class for migrating among versions of the InvokeAI app config schema.
import locale import locale
import shutil import shutil
from dataclasses import dataclass from copy import deepcopy
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import Callable, List, TypeAlias
import yaml import yaml
from packaging.version import Version from packaging.version import Version
import invokeai.configs as model_configs import invokeai.configs as model_configs
from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry
from invokeai.app.services.config.migrations import config_migration_1, config_migration_2
from invokeai.frontend.cli.arg_parser import InvokeAIArgs from invokeai.frontend.cli.arg_parser import InvokeAIArgs
from .config_default import CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, URLRegexTokenPair from .config_default import CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, URLRegexTokenPair
from .migrations import AppConfigDict, Migrations, MigrationsBase
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
@dataclass
class MigrationEntry:
"""Defines an individual migration."""
from_version: Version
to_version: Version
function: MigrationFunction
class ConfigMigrator: class ConfigMigrator:
"""This class allows migrators to register their input and output versions.""" """This class allows migrators to register their input and output versions."""
def __init__(self, migrations: type[MigrationsBase]) -> None: def __init__(self) -> None:
self._migrations: List[MigrationEntry] = [] self._migrations: set[MigrationEntry] = set()
migrations.load(self)
def register( def register(self, migration: MigrationEntry) -> None:
self, migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations)
from_version: str, migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations)
to_version: str, if migration_from_already_registered or migration_to_already_registered:
) -> Callable[[MigrationFunction], MigrationFunction]:
"""Define a decorator which registers the migration between two versions."""
def decorator(function: MigrationFunction) -> MigrationFunction:
if any((from_version == m.from_version) or (to_version == m.to_version) for m in self._migrations):
raise ValueError( raise ValueError(
f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered." f"A migration from {migration.from_version} or to {migration.to_version} has already been registered."
) )
self._migrations.append( self._migrations.add(migration)
MigrationEntry(from_version=Version(from_version), to_version=Version(to_version), function=function)
)
return function
return decorator
@staticmethod @staticmethod
def _check_for_discontinuities(migrations: List[MigrationEntry]) -> None: def _check_for_discontinuities(migrations: list[MigrationEntry]) -> None:
current_version = Version("3.0.0") current_version = Version("3.0.0")
for m in migrations: for m in migrations:
if current_version != m.from_version: if current_version != m.from_version:
@ -68,35 +46,38 @@ class ConfigMigrator:
) )
current_version = m.to_version current_version = m.to_version
def run_migrations(self, config_dict: AppConfigDict) -> AppConfigDict: def run_migrations(self, original_config: AppConfigDict) -> AppConfigDict:
""" """
Use the registered migration steps to bring config up to latest version. Use the registered migrations to bring config up to latest version.
:param config: The original configuration. Args:
:return: The new configuration, lifted up to the latest version. original_config: The original configuration.
As a side effect, the new configuration will be written to disk. Returns:
If an inconsistency in the registered migration steps' `from_version` The new configuration, lifted up to the latest version.
and `to_version` parameters are identified, this will raise a
ValueError exception.
""" """
# Sort migrations by version number and raise a ValueError if
# any version range overlaps are detected. # Sort migrations by version number and raise a ValueError if any version range overlaps are detected.
sorted_migrations = sorted(self._migrations, key=lambda x: x.from_version) sorted_migrations = sorted(self._migrations, key=lambda x: x.from_version)
self._check_for_discontinuities(sorted_migrations) self._check_for_discontinuities(sorted_migrations)
if "InvokeAI" in config_dict: # Do not mutate the incoming dict - we don't know who else may be using it
migrated_config = deepcopy(original_config)
# v3.0.0 configs did not have "schema_version", but did have "InvokeAI"
if "InvokeAI" in migrated_config:
version = Version("3.0.0") version = Version("3.0.0")
else: else:
version = Version(config_dict["schema_version"]) version = Version(migrated_config["schema_version"])
for migration in sorted_migrations: for migration in sorted_migrations:
if version == migration.from_version and version < migration.to_version: if version == migration.from_version:
config_dict = migration.function(config_dict) migrated_config = migration.function(migrated_config)
version = migration.to_version version = migration.to_version
config_dict["schema_version"] = str(version) # We must end on the latest version
return config_dict assert migrated_config["schema_version"] == str(sorted_migrations[-1].to_version)
return migrated_config
@lru_cache(maxsize=1) @lru_cache(maxsize=1)
@ -165,14 +146,16 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
""" """
assert config_path.suffix == ".yaml" assert config_path.suffix == ".yaml"
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file: with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
loaded_config_dict = yaml.safe_load(file) loaded_config_dict: AppConfigDict = yaml.safe_load(file)
assert isinstance(loaded_config_dict, dict) assert isinstance(loaded_config_dict, dict)
shutil.copy(config_path, config_path.with_suffix(".yaml.bak")) shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try: try:
migrator = ConfigMigrator(Migrations) migrator = ConfigMigrator()
migrated_config_dict = migrator.run_migrations(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType] migrator.register(config_migration_1)
migrator.register(config_migration_2)
migrated_config_dict = migrator.run_migrations(loaded_config_dict)
except Exception as e: except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path) shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e

View File

@ -10,65 +10,46 @@ To define a new migration, add a migration function to
Migrations.load_migrations() following the existing examples. Migrations.load_migrations() following the existing examples.
""" """
from abc import ABC, abstractmethod
from pathlib import Path from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeAlias
from packaging.version import Version
from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry
from .config_default import InvokeAIAppConfig from .config_default import InvokeAIAppConfig
if TYPE_CHECKING:
from .config_migrate import ConfigMigrator
AppConfigDict: TypeAlias = dict[str, Any] def migrate_v300_to_v400(original_config: AppConfigDict) -> AppConfigDict:
"""Migrate a v3.0.0 config dict to v4.0.0.
Changes in this migration:
class MigrationsBase(ABC): - `outdir` was renamed to `outputs_dir`
"""Define the config file migration steps to apply, abstract base class.""" - `max_cache_size` was renamed to `ram`
- `max_vram_cache_size` was renamed to `vram`
@classmethod - `conf_path`, which pointed to the old `models.yaml`, was removed - but if need to stash it to migrate the entries
@abstractmethod to the database
def load(cls, migrator: "ConfigMigrator") -> None: - `legacy_conf_dir` was changed from a path relative to the app root, to a path relative to $INVOKEAI_ROOT/configs
"""Use the provided migrator to register the configuration migrations to be run."""
class Migrations(MigrationsBase):
"""Configuration migration steps to apply."""
@classmethod
def load(cls, migrator: "ConfigMigrator") -> None:
"""Define migrations to perform."""
##################
# 3.0.0 -> 4.0.0 #
##################
@migrator.register(from_version="3.0.0", to_version="4.0.0")
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate a v3 config dictionary to a current config object.
Args: Args:
config_dict: A dictionary of settings from a v3 config file. config_dict: The v3.0.0 config dict to migrate.
Returns: Returns:
A dictionary of settings from a 4.0.0 config file. The migrated v4.0.0 config dict.
""" """
parsed_config_dict: dict[str, Any] = {} migrated_config: AppConfigDict = {}
for _category_name, category_dict in config_dict["InvokeAI"].items(): for _category_name, category_dict in original_config["InvokeAI"].items():
for k, v in category_dict.items(): for k, v in category_dict.items():
# `outdir` was renamed to `outputs_dir` in v4 # `outdir` was renamed to `outputs_dir` in v4
if k == "outdir": if k == "outdir":
parsed_config_dict["outputs_dir"] = v migrated_config["outputs_dir"] = v
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used # `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict: if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v migrated_config["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict: if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v migrated_config["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
if k == "conf_path": if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v migrated_config["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir": if k == "legacy_conf_dir":
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows). # The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion": if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
@ -76,33 +57,43 @@ class Migrations(MigrationsBase):
continue continue
elif Path(v).name == "stable-diffusion": elif Path(v).name == "stable-diffusion":
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path. # Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent) migrated_config["legacy_conf_dir"] = str(Path(v).parent)
else: else:
# Else we do not attempt to migrate this setting # Else we do not attempt to migrate this setting
parsed_config_dict["legacy_conf_dir"] = v migrated_config["legacy_conf_dir"] = v
elif k in InvokeAIAppConfig.model_fields: elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields # skip unknown fields
parsed_config_dict[k] = v migrated_config[k] = v
return parsed_config_dict migrated_config["schema_version"] = "4.0.0"
return migrated_config
##################
# 4.0.0 -> 4.0.1 # config_migration_1 = MigrationEntry(
################## from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migrate_v300_to_v400
@migrator.register(from_version="4.0.0", to_version="4.0.1") )
def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.0 config dictionary to v4.0.1.
def migrate_v400_to_v401(original_config: AppConfigDict) -> AppConfigDict:
"""Migrate a v4.0.0 config dict to v4.0.1.
Changes in this migration:
- `precision: "autocast"` was removed, fall back to "auto"
Args: Args:
config_dict: A dictionary of settings from a v4.0.0 config file. config_dict: The v4.0.0 config dict to migrate.
Returns: Returns:
A dictionary of settings from a v4.0.1 config file The migrated v4.0.1 config dict.
""" """
parsed_config_dict: dict[str, Any] = {} migrated_config: AppConfigDict = {}
for k, v in config_dict.items(): for k, v in original_config.items():
# autocast was removed from precision in v4.0.1 # autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast": if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto" migrated_config["precision"] = "auto"
else: migrated_config["schema_version"] = "4.0.1"
parsed_config_dict[k] = v return migrated_config
return parsed_config_dict
config_migration_2 = MigrationEntry(
from_version=Version("4.0.0"), to_version=Version("4.0.1"), function=migrate_v400_to_v401
)