feat(config): simplify config migrator logic

- Remove `Migrations` class - unnecessary complexity on top of `MigrationEntry`
- Move common classes to `config_common`
- Tidy docstrings, variable names
This commit is contained in:
psychedelicious 2024-05-14 16:20:40 +10:00
parent fc23b16a73
commit 6946a3871f
3 changed files with 137 additions and 141 deletions

View File

@ -12,6 +12,10 @@ from __future__ import annotations
import argparse
import pydoc
from dataclasses import dataclass
from typing import Any, Callable, TypeAlias
from packaging.version import Version
class PagingArgumentParser(argparse.ArgumentParser):
@ -23,3 +27,21 @@ class PagingArgumentParser(argparse.ArgumentParser):
def print_help(self, file=None) -> None:
text = self.format_help()
pydoc.pager(text)
AppConfigDict: TypeAlias = dict[str, Any]
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
@dataclass
class MigrationEntry:
"""Defines an individual migration."""
from_version: Version
to_version: Version
function: MigrationFunction
def __hash__(self) -> int:
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
return hash((self.from_version, self.to_version))

View File

@ -6,60 +6,38 @@ Utility class for migrating among versions of the InvokeAI app config schema.
import locale
import shutil
from dataclasses import dataclass
from copy import deepcopy
from functools import lru_cache
from pathlib import Path
from typing import Callable, List, TypeAlias
import yaml
from packaging.version import Version
import invokeai.configs as model_configs
from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry
from invokeai.app.services.config.migrations import config_migration_1, config_migration_2
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
from .config_default import CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, URLRegexTokenPair
from .migrations import AppConfigDict, Migrations, MigrationsBase
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
@dataclass
class MigrationEntry:
"""Defines an individual migration."""
from_version: Version
to_version: Version
function: MigrationFunction
class ConfigMigrator:
"""This class allows migrators to register their input and output versions."""
def __init__(self, migrations: type[MigrationsBase]) -> None:
self._migrations: List[MigrationEntry] = []
migrations.load(self)
def __init__(self) -> None:
self._migrations: set[MigrationEntry] = set()
def register(
self,
from_version: str,
to_version: str,
) -> Callable[[MigrationFunction], MigrationFunction]:
"""Define a decorator which registers the migration between two versions."""
def decorator(function: MigrationFunction) -> MigrationFunction:
if any((from_version == m.from_version) or (to_version == m.to_version) for m in self._migrations):
raise ValueError(
f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered."
)
self._migrations.append(
MigrationEntry(from_version=Version(from_version), to_version=Version(to_version), function=function)
def register(self, migration: MigrationEntry) -> None:
migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations)
migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations)
if migration_from_already_registered or migration_to_already_registered:
raise ValueError(
f"A migration from {migration.from_version} or to {migration.to_version} has already been registered."
)
return function
return decorator
self._migrations.add(migration)
@staticmethod
def _check_for_discontinuities(migrations: List[MigrationEntry]) -> None:
def _check_for_discontinuities(migrations: list[MigrationEntry]) -> None:
current_version = Version("3.0.0")
for m in migrations:
if current_version != m.from_version:
@ -68,35 +46,38 @@ class ConfigMigrator:
)
current_version = m.to_version
def run_migrations(self, config_dict: AppConfigDict) -> AppConfigDict:
def run_migrations(self, original_config: AppConfigDict) -> AppConfigDict:
"""
Use the registered migration steps to bring config up to latest version.
Use the registered migrations to bring config up to latest version.
:param config: The original configuration.
:return: The new configuration, lifted up to the latest version.
Args:
original_config: The original configuration.
As a side effect, the new configuration will be written to disk.
If an inconsistency in the registered migration steps' `from_version`
and `to_version` parameters are identified, this will raise a
ValueError exception.
Returns:
The new configuration, lifted up to the latest version.
"""
# Sort migrations by version number and raise a ValueError if
# any version range overlaps are detected.
# Sort migrations by version number and raise a ValueError if any version range overlaps are detected.
sorted_migrations = sorted(self._migrations, key=lambda x: x.from_version)
self._check_for_discontinuities(sorted_migrations)
if "InvokeAI" in config_dict:
# Do not mutate the incoming dict - we don't know who else may be using it
migrated_config = deepcopy(original_config)
# v3.0.0 configs did not have "schema_version", but did have "InvokeAI"
if "InvokeAI" in migrated_config:
version = Version("3.0.0")
else:
version = Version(config_dict["schema_version"])
version = Version(migrated_config["schema_version"])
for migration in sorted_migrations:
if version == migration.from_version and version < migration.to_version:
config_dict = migration.function(config_dict)
if version == migration.from_version:
migrated_config = migration.function(migrated_config)
version = migration.to_version
config_dict["schema_version"] = str(version)
return config_dict
# We must end on the latest version
assert migrated_config["schema_version"] == str(sorted_migrations[-1].to_version)
return migrated_config
@lru_cache(maxsize=1)
@ -165,14 +146,16 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""
assert config_path.suffix == ".yaml"
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
loaded_config_dict = yaml.safe_load(file)
loaded_config_dict: AppConfigDict = yaml.safe_load(file)
assert isinstance(loaded_config_dict, dict)
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
migrator = ConfigMigrator(Migrations)
migrated_config_dict = migrator.run_migrations(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
migrator = ConfigMigrator()
migrator.register(config_migration_1)
migrator.register(config_migration_2)
migrated_config_dict = migrator.run_migrations(loaded_config_dict)
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e

View File

@ -10,99 +10,90 @@ To define a new migration, add a migration function to
Migrations.load_migrations() following the existing examples.
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeAlias
from packaging.version import Version
from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry
from .config_default import InvokeAIAppConfig
if TYPE_CHECKING:
from .config_migrate import ConfigMigrator
AppConfigDict: TypeAlias = dict[str, Any]
def migrate_v300_to_v400(original_config: AppConfigDict) -> AppConfigDict:
"""Migrate a v3.0.0 config dict to v4.0.0.
Changes in this migration:
- `outdir` was renamed to `outputs_dir`
- `max_cache_size` was renamed to `ram`
- `max_vram_cache_size` was renamed to `vram`
- `conf_path`, which pointed to the old `models.yaml`, was removed - but if need to stash it to migrate the entries
to the database
- `legacy_conf_dir` was changed from a path relative to the app root, to a path relative to $INVOKEAI_ROOT/configs
class MigrationsBase(ABC):
"""Define the config file migration steps to apply, abstract base class."""
Args:
config_dict: The v3.0.0 config dict to migrate.
@classmethod
@abstractmethod
def load(cls, migrator: "ConfigMigrator") -> None:
"""Use the provided migrator to register the configuration migrations to be run."""
class Migrations(MigrationsBase):
"""Configuration migration steps to apply."""
@classmethod
def load(cls, migrator: "ConfigMigrator") -> None:
"""Define migrations to perform."""
##################
# 3.0.0 -> 4.0.0 #
##################
@migrator.register(from_version="3.0.0", to_version="4.0.0")
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate a v3 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v3 config file.
Returns:
A dictionary of settings from a 4.0.0 config file.
"""
parsed_config_dict: dict[str, Any] = {}
for _category_name, category_dict in config_dict["InvokeAI"].items():
for k, v in category_dict.items():
# `outdir` was renamed to `outputs_dir` in v4
if k == "outdir":
parsed_config_dict["outputs_dir"] = v
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
# If if the incoming config has the default value, skip
continue
elif Path(v).name == "stable-diffusion":
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
else:
# Else we do not attempt to migrate this setting
parsed_config_dict["legacy_conf_dir"] = v
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
return parsed_config_dict
##################
# 4.0.0 -> 4.0.1 #
##################
@migrator.register(from_version="4.0.0", to_version="4.0.1")
def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.0 config dictionary to v4.0.1.
Args:
config_dict: A dictionary of settings from a v4.0.0 config file.
Returns:
A dictionary of settings from a v4.0.1 config file
"""
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
Returns:
The migrated v4.0.0 config dict.
"""
migrated_config: AppConfigDict = {}
for _category_name, category_dict in original_config["InvokeAI"].items():
for k, v in category_dict.items():
# `outdir` was renamed to `outputs_dir` in v4
if k == "outdir":
migrated_config["outputs_dir"] = v
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
migrated_config["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
migrated_config["vram"] = v
if k == "conf_path":
migrated_config["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
# If if the incoming config has the default value, skip
continue
elif Path(v).name == "stable-diffusion":
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
migrated_config["legacy_conf_dir"] = str(Path(v).parent)
else:
parsed_config_dict[k] = v
return parsed_config_dict
# Else we do not attempt to migrate this setting
migrated_config["legacy_conf_dir"] = v
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
migrated_config[k] = v
migrated_config["schema_version"] = "4.0.0"
return migrated_config
config_migration_1 = MigrationEntry(
from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migrate_v300_to_v400
)
def migrate_v400_to_v401(original_config: AppConfigDict) -> AppConfigDict:
"""Migrate a v4.0.0 config dict to v4.0.1.
Changes in this migration:
- `precision: "autocast"` was removed, fall back to "auto"
Args:
config_dict: The v4.0.0 config dict to migrate.
Returns:
The migrated v4.0.1 config dict.
"""
migrated_config: AppConfigDict = {}
for k, v in original_config.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
migrated_config["precision"] = "auto"
migrated_config["schema_version"] = "4.0.1"
return migrated_config
config_migration_2 = MigrationEntry(
from_version=Version("4.0.0"), to_version=Version("4.0.1"), function=migrate_v400_to_v401
)