mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(config): simplify config migrator logic
- Remove `Migrations` class - unnecessary complexity on top of `MigrationEntry` - Move common classes to `config_common` - Tidy docstrings, variable names
This commit is contained in:
parent
fc23b16a73
commit
6946a3871f
@ -12,6 +12,10 @@ from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import pydoc
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, TypeAlias
|
||||
|
||||
from packaging.version import Version
|
||||
|
||||
|
||||
class PagingArgumentParser(argparse.ArgumentParser):
|
||||
@ -23,3 +27,21 @@ class PagingArgumentParser(argparse.ArgumentParser):
|
||||
def print_help(self, file=None) -> None:
|
||||
text = self.format_help()
|
||||
pydoc.pager(text)
|
||||
|
||||
|
||||
AppConfigDict: TypeAlias = dict[str, Any]
|
||||
|
||||
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MigrationEntry:
|
||||
"""Defines an individual migration."""
|
||||
|
||||
from_version: Version
|
||||
to_version: Version
|
||||
function: MigrationFunction
|
||||
|
||||
def __hash__(self) -> int:
|
||||
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
|
||||
return hash((self.from_version, self.to_version))
|
||||
|
@ -6,60 +6,38 @@ Utility class for migrating among versions of the InvokeAI app config schema.
|
||||
|
||||
import locale
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from copy import deepcopy
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, TypeAlias
|
||||
|
||||
import yaml
|
||||
from packaging.version import Version
|
||||
|
||||
import invokeai.configs as model_configs
|
||||
from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry
|
||||
from invokeai.app.services.config.migrations import config_migration_1, config_migration_2
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
from .config_default import CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, URLRegexTokenPair
|
||||
from .migrations import AppConfigDict, Migrations, MigrationsBase
|
||||
|
||||
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
|
||||
|
||||
|
||||
@dataclass
|
||||
class MigrationEntry:
|
||||
"""Defines an individual migration."""
|
||||
|
||||
from_version: Version
|
||||
to_version: Version
|
||||
function: MigrationFunction
|
||||
|
||||
|
||||
class ConfigMigrator:
|
||||
"""This class allows migrators to register their input and output versions."""
|
||||
|
||||
def __init__(self, migrations: type[MigrationsBase]) -> None:
|
||||
self._migrations: List[MigrationEntry] = []
|
||||
migrations.load(self)
|
||||
def __init__(self) -> None:
|
||||
self._migrations: set[MigrationEntry] = set()
|
||||
|
||||
def register(
|
||||
self,
|
||||
from_version: str,
|
||||
to_version: str,
|
||||
) -> Callable[[MigrationFunction], MigrationFunction]:
|
||||
"""Define a decorator which registers the migration between two versions."""
|
||||
|
||||
def decorator(function: MigrationFunction) -> MigrationFunction:
|
||||
if any((from_version == m.from_version) or (to_version == m.to_version) for m in self._migrations):
|
||||
raise ValueError(
|
||||
f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered."
|
||||
)
|
||||
self._migrations.append(
|
||||
MigrationEntry(from_version=Version(from_version), to_version=Version(to_version), function=function)
|
||||
def register(self, migration: MigrationEntry) -> None:
|
||||
migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations)
|
||||
migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations)
|
||||
if migration_from_already_registered or migration_to_already_registered:
|
||||
raise ValueError(
|
||||
f"A migration from {migration.from_version} or to {migration.to_version} has already been registered."
|
||||
)
|
||||
return function
|
||||
|
||||
return decorator
|
||||
self._migrations.add(migration)
|
||||
|
||||
@staticmethod
|
||||
def _check_for_discontinuities(migrations: List[MigrationEntry]) -> None:
|
||||
def _check_for_discontinuities(migrations: list[MigrationEntry]) -> None:
|
||||
current_version = Version("3.0.0")
|
||||
for m in migrations:
|
||||
if current_version != m.from_version:
|
||||
@ -68,35 +46,38 @@ class ConfigMigrator:
|
||||
)
|
||||
current_version = m.to_version
|
||||
|
||||
def run_migrations(self, config_dict: AppConfigDict) -> AppConfigDict:
|
||||
def run_migrations(self, original_config: AppConfigDict) -> AppConfigDict:
|
||||
"""
|
||||
Use the registered migration steps to bring config up to latest version.
|
||||
Use the registered migrations to bring config up to latest version.
|
||||
|
||||
:param config: The original configuration.
|
||||
:return: The new configuration, lifted up to the latest version.
|
||||
Args:
|
||||
original_config: The original configuration.
|
||||
|
||||
As a side effect, the new configuration will be written to disk.
|
||||
If an inconsistency in the registered migration steps' `from_version`
|
||||
and `to_version` parameters are identified, this will raise a
|
||||
ValueError exception.
|
||||
Returns:
|
||||
The new configuration, lifted up to the latest version.
|
||||
"""
|
||||
# Sort migrations by version number and raise a ValueError if
|
||||
# any version range overlaps are detected.
|
||||
|
||||
# Sort migrations by version number and raise a ValueError if any version range overlaps are detected.
|
||||
sorted_migrations = sorted(self._migrations, key=lambda x: x.from_version)
|
||||
self._check_for_discontinuities(sorted_migrations)
|
||||
|
||||
if "InvokeAI" in config_dict:
|
||||
# Do not mutate the incoming dict - we don't know who else may be using it
|
||||
migrated_config = deepcopy(original_config)
|
||||
|
||||
# v3.0.0 configs did not have "schema_version", but did have "InvokeAI"
|
||||
if "InvokeAI" in migrated_config:
|
||||
version = Version("3.0.0")
|
||||
else:
|
||||
version = Version(config_dict["schema_version"])
|
||||
version = Version(migrated_config["schema_version"])
|
||||
|
||||
for migration in sorted_migrations:
|
||||
if version == migration.from_version and version < migration.to_version:
|
||||
config_dict = migration.function(config_dict)
|
||||
if version == migration.from_version:
|
||||
migrated_config = migration.function(migrated_config)
|
||||
version = migration.to_version
|
||||
|
||||
config_dict["schema_version"] = str(version)
|
||||
return config_dict
|
||||
# We must end on the latest version
|
||||
assert migrated_config["schema_version"] == str(sorted_migrations[-1].to_version)
|
||||
return migrated_config
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
@ -165,14 +146,16 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
"""
|
||||
assert config_path.suffix == ".yaml"
|
||||
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
|
||||
loaded_config_dict = yaml.safe_load(file)
|
||||
loaded_config_dict: AppConfigDict = yaml.safe_load(file)
|
||||
|
||||
assert isinstance(loaded_config_dict, dict)
|
||||
|
||||
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
|
||||
try:
|
||||
migrator = ConfigMigrator(Migrations)
|
||||
migrated_config_dict = migrator.run_migrations(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
|
||||
migrator = ConfigMigrator()
|
||||
migrator.register(config_migration_1)
|
||||
migrator.register(config_migration_2)
|
||||
migrated_config_dict = migrator.run_migrations(loaded_config_dict)
|
||||
except Exception as e:
|
||||
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
|
||||
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e
|
||||
|
@ -10,99 +10,90 @@ To define a new migration, add a migration function to
|
||||
Migrations.load_migrations() following the existing examples.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias
|
||||
|
||||
from packaging.version import Version
|
||||
|
||||
from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry
|
||||
|
||||
from .config_default import InvokeAIAppConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .config_migrate import ConfigMigrator
|
||||
|
||||
AppConfigDict: TypeAlias = dict[str, Any]
|
||||
def migrate_v300_to_v400(original_config: AppConfigDict) -> AppConfigDict:
|
||||
"""Migrate a v3.0.0 config dict to v4.0.0.
|
||||
|
||||
Changes in this migration:
|
||||
- `outdir` was renamed to `outputs_dir`
|
||||
- `max_cache_size` was renamed to `ram`
|
||||
- `max_vram_cache_size` was renamed to `vram`
|
||||
- `conf_path`, which pointed to the old `models.yaml`, was removed - but if need to stash it to migrate the entries
|
||||
to the database
|
||||
- `legacy_conf_dir` was changed from a path relative to the app root, to a path relative to $INVOKEAI_ROOT/configs
|
||||
|
||||
class MigrationsBase(ABC):
|
||||
"""Define the config file migration steps to apply, abstract base class."""
|
||||
Args:
|
||||
config_dict: The v3.0.0 config dict to migrate.
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def load(cls, migrator: "ConfigMigrator") -> None:
|
||||
"""Use the provided migrator to register the configuration migrations to be run."""
|
||||
|
||||
|
||||
class Migrations(MigrationsBase):
|
||||
"""Configuration migration steps to apply."""
|
||||
|
||||
@classmethod
|
||||
def load(cls, migrator: "ConfigMigrator") -> None:
|
||||
"""Define migrations to perform."""
|
||||
|
||||
##################
|
||||
# 3.0.0 -> 4.0.0 #
|
||||
##################
|
||||
@migrator.register(from_version="3.0.0", to_version="4.0.0")
|
||||
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Migrate a v3 config dictionary to a current config object.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v3 config file.
|
||||
|
||||
Returns:
|
||||
A dictionary of settings from a 4.0.0 config file.
|
||||
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for _category_name, category_dict in config_dict["InvokeAI"].items():
|
||||
for k, v in category_dict.items():
|
||||
# `outdir` was renamed to `outputs_dir` in v4
|
||||
if k == "outdir":
|
||||
parsed_config_dict["outputs_dir"] = v
|
||||
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
|
||||
if k == "max_cache_size" and "ram" not in category_dict:
|
||||
parsed_config_dict["ram"] = v
|
||||
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
||||
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
||||
parsed_config_dict["vram"] = v
|
||||
# autocast was removed in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
if k == "conf_path":
|
||||
parsed_config_dict["legacy_models_yaml_path"] = v
|
||||
if k == "legacy_conf_dir":
|
||||
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
|
||||
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
|
||||
# If if the incoming config has the default value, skip
|
||||
continue
|
||||
elif Path(v).name == "stable-diffusion":
|
||||
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
||||
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
|
||||
else:
|
||||
# Else we do not attempt to migrate this setting
|
||||
parsed_config_dict["legacy_conf_dir"] = v
|
||||
elif k in InvokeAIAppConfig.model_fields:
|
||||
# skip unknown fields
|
||||
parsed_config_dict[k] = v
|
||||
return parsed_config_dict
|
||||
|
||||
##################
|
||||
# 4.0.0 -> 4.0.1 #
|
||||
##################
|
||||
@migrator.register(from_version="4.0.0", to_version="4.0.1")
|
||||
def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Migrate v4.0.0 config dictionary to v4.0.1.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v4.0.0 config file.
|
||||
|
||||
Returns:
|
||||
A dictionary of settings from a v4.0.1 config file
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for k, v in config_dict.items():
|
||||
# autocast was removed from precision in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
Returns:
|
||||
The migrated v4.0.0 config dict.
|
||||
"""
|
||||
migrated_config: AppConfigDict = {}
|
||||
for _category_name, category_dict in original_config["InvokeAI"].items():
|
||||
for k, v in category_dict.items():
|
||||
# `outdir` was renamed to `outputs_dir` in v4
|
||||
if k == "outdir":
|
||||
migrated_config["outputs_dir"] = v
|
||||
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
|
||||
if k == "max_cache_size" and "ram" not in category_dict:
|
||||
migrated_config["ram"] = v
|
||||
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
||||
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
||||
migrated_config["vram"] = v
|
||||
if k == "conf_path":
|
||||
migrated_config["legacy_models_yaml_path"] = v
|
||||
if k == "legacy_conf_dir":
|
||||
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
|
||||
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
|
||||
# If if the incoming config has the default value, skip
|
||||
continue
|
||||
elif Path(v).name == "stable-diffusion":
|
||||
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
||||
migrated_config["legacy_conf_dir"] = str(Path(v).parent)
|
||||
else:
|
||||
parsed_config_dict[k] = v
|
||||
return parsed_config_dict
|
||||
# Else we do not attempt to migrate this setting
|
||||
migrated_config["legacy_conf_dir"] = v
|
||||
elif k in InvokeAIAppConfig.model_fields:
|
||||
# skip unknown fields
|
||||
migrated_config[k] = v
|
||||
migrated_config["schema_version"] = "4.0.0"
|
||||
return migrated_config
|
||||
|
||||
|
||||
config_migration_1 = MigrationEntry(
|
||||
from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migrate_v300_to_v400
|
||||
)
|
||||
|
||||
|
||||
def migrate_v400_to_v401(original_config: AppConfigDict) -> AppConfigDict:
|
||||
"""Migrate a v4.0.0 config dict to v4.0.1.
|
||||
|
||||
Changes in this migration:
|
||||
- `precision: "autocast"` was removed, fall back to "auto"
|
||||
|
||||
Args:
|
||||
config_dict: The v4.0.0 config dict to migrate.
|
||||
|
||||
Returns:
|
||||
The migrated v4.0.1 config dict.
|
||||
"""
|
||||
migrated_config: AppConfigDict = {}
|
||||
for k, v in original_config.items():
|
||||
# autocast was removed from precision in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
migrated_config["precision"] = "auto"
|
||||
migrated_config["schema_version"] = "4.0.1"
|
||||
return migrated_config
|
||||
|
||||
|
||||
config_migration_2 = MigrationEntry(
|
||||
from_version=Version("4.0.0"), to_version=Version("4.0.1"), function=migrate_v400_to_v401
|
||||
)
|
||||
|
Loading…
Reference in New Issue
Block a user