mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
feat(config): simplify config migrator logic
- Remove `Migrations` class - unnecessary complexity on top of `MigrationEntry` - Move common classes to `config_common` - Tidy docstrings, variable names
This commit is contained in:
parent
fc23b16a73
commit
6946a3871f
@ -12,6 +12,10 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import argparse
|
import argparse
|
||||||
import pydoc
|
import pydoc
|
||||||
|
from dataclasses import dataclass
|
||||||
|
from typing import Any, Callable, TypeAlias
|
||||||
|
|
||||||
|
from packaging.version import Version
|
||||||
|
|
||||||
|
|
||||||
class PagingArgumentParser(argparse.ArgumentParser):
|
class PagingArgumentParser(argparse.ArgumentParser):
|
||||||
@ -23,3 +27,21 @@ class PagingArgumentParser(argparse.ArgumentParser):
|
|||||||
def print_help(self, file=None) -> None:
|
def print_help(self, file=None) -> None:
|
||||||
text = self.format_help()
|
text = self.format_help()
|
||||||
pydoc.pager(text)
|
pydoc.pager(text)
|
||||||
|
|
||||||
|
|
||||||
|
AppConfigDict: TypeAlias = dict[str, Any]
|
||||||
|
|
||||||
|
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
|
||||||
|
|
||||||
|
|
||||||
|
@dataclass
|
||||||
|
class MigrationEntry:
|
||||||
|
"""Defines an individual migration."""
|
||||||
|
|
||||||
|
from_version: Version
|
||||||
|
to_version: Version
|
||||||
|
function: MigrationFunction
|
||||||
|
|
||||||
|
def __hash__(self) -> int:
|
||||||
|
# Callables are not hashable, so we need to implement our own __hash__ function to use this class in a set.
|
||||||
|
return hash((self.from_version, self.to_version))
|
||||||
|
@ -6,60 +6,38 @@ Utility class for migrating among versions of the InvokeAI app config schema.
|
|||||||
|
|
||||||
import locale
|
import locale
|
||||||
import shutil
|
import shutil
|
||||||
from dataclasses import dataclass
|
from copy import deepcopy
|
||||||
from functools import lru_cache
|
from functools import lru_cache
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, List, TypeAlias
|
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
|
||||||
import invokeai.configs as model_configs
|
import invokeai.configs as model_configs
|
||||||
|
from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry
|
||||||
|
from invokeai.app.services.config.migrations import config_migration_1, config_migration_2
|
||||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||||
|
|
||||||
from .config_default import CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, URLRegexTokenPair
|
from .config_default import CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, URLRegexTokenPair
|
||||||
from .migrations import AppConfigDict, Migrations, MigrationsBase
|
|
||||||
|
|
||||||
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
|
||||||
class MigrationEntry:
|
|
||||||
"""Defines an individual migration."""
|
|
||||||
|
|
||||||
from_version: Version
|
|
||||||
to_version: Version
|
|
||||||
function: MigrationFunction
|
|
||||||
|
|
||||||
|
|
||||||
class ConfigMigrator:
|
class ConfigMigrator:
|
||||||
"""This class allows migrators to register their input and output versions."""
|
"""This class allows migrators to register their input and output versions."""
|
||||||
|
|
||||||
def __init__(self, migrations: type[MigrationsBase]) -> None:
|
def __init__(self) -> None:
|
||||||
self._migrations: List[MigrationEntry] = []
|
self._migrations: set[MigrationEntry] = set()
|
||||||
migrations.load(self)
|
|
||||||
|
|
||||||
def register(
|
def register(self, migration: MigrationEntry) -> None:
|
||||||
self,
|
migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations)
|
||||||
from_version: str,
|
migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations)
|
||||||
to_version: str,
|
if migration_from_already_registered or migration_to_already_registered:
|
||||||
) -> Callable[[MigrationFunction], MigrationFunction]:
|
|
||||||
"""Define a decorator which registers the migration between two versions."""
|
|
||||||
|
|
||||||
def decorator(function: MigrationFunction) -> MigrationFunction:
|
|
||||||
if any((from_version == m.from_version) or (to_version == m.to_version) for m in self._migrations):
|
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered."
|
f"A migration from {migration.from_version} or to {migration.to_version} has already been registered."
|
||||||
)
|
)
|
||||||
self._migrations.append(
|
self._migrations.add(migration)
|
||||||
MigrationEntry(from_version=Version(from_version), to_version=Version(to_version), function=function)
|
|
||||||
)
|
|
||||||
return function
|
|
||||||
|
|
||||||
return decorator
|
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def _check_for_discontinuities(migrations: List[MigrationEntry]) -> None:
|
def _check_for_discontinuities(migrations: list[MigrationEntry]) -> None:
|
||||||
current_version = Version("3.0.0")
|
current_version = Version("3.0.0")
|
||||||
for m in migrations:
|
for m in migrations:
|
||||||
if current_version != m.from_version:
|
if current_version != m.from_version:
|
||||||
@ -68,35 +46,38 @@ class ConfigMigrator:
|
|||||||
)
|
)
|
||||||
current_version = m.to_version
|
current_version = m.to_version
|
||||||
|
|
||||||
def run_migrations(self, config_dict: AppConfigDict) -> AppConfigDict:
|
def run_migrations(self, original_config: AppConfigDict) -> AppConfigDict:
|
||||||
"""
|
"""
|
||||||
Use the registered migration steps to bring config up to latest version.
|
Use the registered migrations to bring config up to latest version.
|
||||||
|
|
||||||
:param config: The original configuration.
|
Args:
|
||||||
:return: The new configuration, lifted up to the latest version.
|
original_config: The original configuration.
|
||||||
|
|
||||||
As a side effect, the new configuration will be written to disk.
|
Returns:
|
||||||
If an inconsistency in the registered migration steps' `from_version`
|
The new configuration, lifted up to the latest version.
|
||||||
and `to_version` parameters are identified, this will raise a
|
|
||||||
ValueError exception.
|
|
||||||
"""
|
"""
|
||||||
# Sort migrations by version number and raise a ValueError if
|
|
||||||
# any version range overlaps are detected.
|
# Sort migrations by version number and raise a ValueError if any version range overlaps are detected.
|
||||||
sorted_migrations = sorted(self._migrations, key=lambda x: x.from_version)
|
sorted_migrations = sorted(self._migrations, key=lambda x: x.from_version)
|
||||||
self._check_for_discontinuities(sorted_migrations)
|
self._check_for_discontinuities(sorted_migrations)
|
||||||
|
|
||||||
if "InvokeAI" in config_dict:
|
# Do not mutate the incoming dict - we don't know who else may be using it
|
||||||
|
migrated_config = deepcopy(original_config)
|
||||||
|
|
||||||
|
# v3.0.0 configs did not have "schema_version", but did have "InvokeAI"
|
||||||
|
if "InvokeAI" in migrated_config:
|
||||||
version = Version("3.0.0")
|
version = Version("3.0.0")
|
||||||
else:
|
else:
|
||||||
version = Version(config_dict["schema_version"])
|
version = Version(migrated_config["schema_version"])
|
||||||
|
|
||||||
for migration in sorted_migrations:
|
for migration in sorted_migrations:
|
||||||
if version == migration.from_version and version < migration.to_version:
|
if version == migration.from_version:
|
||||||
config_dict = migration.function(config_dict)
|
migrated_config = migration.function(migrated_config)
|
||||||
version = migration.to_version
|
version = migration.to_version
|
||||||
|
|
||||||
config_dict["schema_version"] = str(version)
|
# We must end on the latest version
|
||||||
return config_dict
|
assert migrated_config["schema_version"] == str(sorted_migrations[-1].to_version)
|
||||||
|
return migrated_config
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
@lru_cache(maxsize=1)
|
||||||
@ -165,14 +146,16 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
|||||||
"""
|
"""
|
||||||
assert config_path.suffix == ".yaml"
|
assert config_path.suffix == ".yaml"
|
||||||
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
|
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
|
||||||
loaded_config_dict = yaml.safe_load(file)
|
loaded_config_dict: AppConfigDict = yaml.safe_load(file)
|
||||||
|
|
||||||
assert isinstance(loaded_config_dict, dict)
|
assert isinstance(loaded_config_dict, dict)
|
||||||
|
|
||||||
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
|
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
|
||||||
try:
|
try:
|
||||||
migrator = ConfigMigrator(Migrations)
|
migrator = ConfigMigrator()
|
||||||
migrated_config_dict = migrator.run_migrations(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
|
migrator.register(config_migration_1)
|
||||||
|
migrator.register(config_migration_2)
|
||||||
|
migrated_config_dict = migrator.run_migrations(loaded_config_dict)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
|
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
|
||||||
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e
|
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e
|
||||||
|
@ -10,65 +10,46 @@ To define a new migration, add a migration function to
|
|||||||
Migrations.load_migrations() following the existing examples.
|
Migrations.load_migrations() following the existing examples.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
from abc import ABC, abstractmethod
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import TYPE_CHECKING, Any, TypeAlias
|
|
||||||
|
from packaging.version import Version
|
||||||
|
|
||||||
|
from invokeai.app.services.config.config_common import AppConfigDict, MigrationEntry
|
||||||
|
|
||||||
from .config_default import InvokeAIAppConfig
|
from .config_default import InvokeAIAppConfig
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
|
||||||
from .config_migrate import ConfigMigrator
|
|
||||||
|
|
||||||
AppConfigDict: TypeAlias = dict[str, Any]
|
def migrate_v300_to_v400(original_config: AppConfigDict) -> AppConfigDict:
|
||||||
|
"""Migrate a v3.0.0 config dict to v4.0.0.
|
||||||
|
|
||||||
|
Changes in this migration:
|
||||||
class MigrationsBase(ABC):
|
- `outdir` was renamed to `outputs_dir`
|
||||||
"""Define the config file migration steps to apply, abstract base class."""
|
- `max_cache_size` was renamed to `ram`
|
||||||
|
- `max_vram_cache_size` was renamed to `vram`
|
||||||
@classmethod
|
- `conf_path`, which pointed to the old `models.yaml`, was removed - but if need to stash it to migrate the entries
|
||||||
@abstractmethod
|
to the database
|
||||||
def load(cls, migrator: "ConfigMigrator") -> None:
|
- `legacy_conf_dir` was changed from a path relative to the app root, to a path relative to $INVOKEAI_ROOT/configs
|
||||||
"""Use the provided migrator to register the configuration migrations to be run."""
|
|
||||||
|
|
||||||
|
|
||||||
class Migrations(MigrationsBase):
|
|
||||||
"""Configuration migration steps to apply."""
|
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def load(cls, migrator: "ConfigMigrator") -> None:
|
|
||||||
"""Define migrations to perform."""
|
|
||||||
|
|
||||||
##################
|
|
||||||
# 3.0.0 -> 4.0.0 #
|
|
||||||
##################
|
|
||||||
@migrator.register(from_version="3.0.0", to_version="4.0.0")
|
|
||||||
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Migrate a v3 config dictionary to a current config object.
|
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config_dict: A dictionary of settings from a v3 config file.
|
config_dict: The v3.0.0 config dict to migrate.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary of settings from a 4.0.0 config file.
|
The migrated v4.0.0 config dict.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
parsed_config_dict: dict[str, Any] = {}
|
migrated_config: AppConfigDict = {}
|
||||||
for _category_name, category_dict in config_dict["InvokeAI"].items():
|
for _category_name, category_dict in original_config["InvokeAI"].items():
|
||||||
for k, v in category_dict.items():
|
for k, v in category_dict.items():
|
||||||
# `outdir` was renamed to `outputs_dir` in v4
|
# `outdir` was renamed to `outputs_dir` in v4
|
||||||
if k == "outdir":
|
if k == "outdir":
|
||||||
parsed_config_dict["outputs_dir"] = v
|
migrated_config["outputs_dir"] = v
|
||||||
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
|
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
|
||||||
if k == "max_cache_size" and "ram" not in category_dict:
|
if k == "max_cache_size" and "ram" not in category_dict:
|
||||||
parsed_config_dict["ram"] = v
|
migrated_config["ram"] = v
|
||||||
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
||||||
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
||||||
parsed_config_dict["vram"] = v
|
migrated_config["vram"] = v
|
||||||
# autocast was removed in v4.0.1
|
|
||||||
if k == "precision" and v == "autocast":
|
|
||||||
parsed_config_dict["precision"] = "auto"
|
|
||||||
if k == "conf_path":
|
if k == "conf_path":
|
||||||
parsed_config_dict["legacy_models_yaml_path"] = v
|
migrated_config["legacy_models_yaml_path"] = v
|
||||||
if k == "legacy_conf_dir":
|
if k == "legacy_conf_dir":
|
||||||
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
|
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
|
||||||
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
|
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
|
||||||
@ -76,33 +57,43 @@ class Migrations(MigrationsBase):
|
|||||||
continue
|
continue
|
||||||
elif Path(v).name == "stable-diffusion":
|
elif Path(v).name == "stable-diffusion":
|
||||||
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
||||||
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
|
migrated_config["legacy_conf_dir"] = str(Path(v).parent)
|
||||||
else:
|
else:
|
||||||
# Else we do not attempt to migrate this setting
|
# Else we do not attempt to migrate this setting
|
||||||
parsed_config_dict["legacy_conf_dir"] = v
|
migrated_config["legacy_conf_dir"] = v
|
||||||
elif k in InvokeAIAppConfig.model_fields:
|
elif k in InvokeAIAppConfig.model_fields:
|
||||||
# skip unknown fields
|
# skip unknown fields
|
||||||
parsed_config_dict[k] = v
|
migrated_config[k] = v
|
||||||
return parsed_config_dict
|
migrated_config["schema_version"] = "4.0.0"
|
||||||
|
return migrated_config
|
||||||
|
|
||||||
##################
|
|
||||||
# 4.0.0 -> 4.0.1 #
|
config_migration_1 = MigrationEntry(
|
||||||
##################
|
from_version=Version("3.0.0"), to_version=Version("4.0.0"), function=migrate_v300_to_v400
|
||||||
@migrator.register(from_version="4.0.0", to_version="4.0.1")
|
)
|
||||||
def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Migrate v4.0.0 config dictionary to v4.0.1.
|
|
||||||
|
def migrate_v400_to_v401(original_config: AppConfigDict) -> AppConfigDict:
|
||||||
|
"""Migrate a v4.0.0 config dict to v4.0.1.
|
||||||
|
|
||||||
|
Changes in this migration:
|
||||||
|
- `precision: "autocast"` was removed, fall back to "auto"
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
config_dict: A dictionary of settings from a v4.0.0 config file.
|
config_dict: The v4.0.0 config dict to migrate.
|
||||||
|
|
||||||
Returns:
|
Returns:
|
||||||
A dictionary of settings from a v4.0.1 config file
|
The migrated v4.0.1 config dict.
|
||||||
"""
|
"""
|
||||||
parsed_config_dict: dict[str, Any] = {}
|
migrated_config: AppConfigDict = {}
|
||||||
for k, v in config_dict.items():
|
for k, v in original_config.items():
|
||||||
# autocast was removed from precision in v4.0.1
|
# autocast was removed from precision in v4.0.1
|
||||||
if k == "precision" and v == "autocast":
|
if k == "precision" and v == "autocast":
|
||||||
parsed_config_dict["precision"] = "auto"
|
migrated_config["precision"] = "auto"
|
||||||
else:
|
migrated_config["schema_version"] = "4.0.1"
|
||||||
parsed_config_dict[k] = v
|
return migrated_config
|
||||||
return parsed_config_dict
|
|
||||||
|
|
||||||
|
config_migration_2 = MigrationEntry(
|
||||||
|
from_version=Version("4.0.0"), to_version=Version("4.0.1"), function=migrate_v400_to_v401
|
||||||
|
)
|
||||||
|
Loading…
Reference in New Issue
Block a user