InvokeAI/invokeai/app/services/config/config_migrate.py
psychedelicious d487102904 fix(config): fix config _check_for_discontinuities
Need to sort the migrations first.
2024-05-14 16:55:22 +10:00

175 lines
7.5 KiB
Python

# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team
"""
Utility class for migrating among versions of the InvokeAI app config schema.
"""
import locale
import shutil
from copy import deepcopy
from functools import lru_cache
from pathlib import Path
from typing import Iterable
import yaml
from packaging.version import Version
import invokeai.configs as model_configs
from invokeai.app.services.config.config_common import AppConfigDict, ConfigMigration
from invokeai.app.services.config.migrations import config_migration_1, config_migration_2
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
from .config_default import CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, URLRegexTokenPair
class ConfigMigrator:
    """Registry of config schema migrations that lifts a config dict to the latest schema version.

    Each registered `ConfigMigration` converts a config from its `from_version` to its `to_version`.
    When run, the registered migrations must form an unbroken chain starting at v3.0.0.
    """

    def __init__(self) -> None:
        self._migrations: set[ConfigMigration] = set()

    def register(self, migration: ConfigMigration) -> None:
        """Register a migration.

        Args:
            migration: The migration to register.

        Raises:
            ValueError: If a migration with the same `from_version` or the same `to_version` is already registered.
        """
        migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations)
        migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations)
        if migration_from_already_registered or migration_to_already_registered:
            raise ValueError(
                f"A migration from {migration.from_version} or to {migration.to_version} has already been registered."
            )
        self._migrations.add(migration)

    @staticmethod
    def _check_for_discontinuities(migrations: Iterable[ConfigMigration]) -> None:
        """Verify that the migrations, ordered by `from_version`, form a contiguous chain starting at v3.0.0.

        Args:
            migrations: The migrations to check, in any order.

        Raises:
            ValueError: If a migration's `from_version` differs from the previous migration's `to_version`
                (or from 3.0.0 for the first migration).
        """
        current_version = Version("3.0.0")
        sorted_migrations = sorted(migrations, key=lambda x: x.from_version)
        for m in sorted_migrations:
            if current_version != m.from_version:
                raise ValueError(
                    f"Migration functions are not continuous. Expected from_version={current_version} but got from_version={m.from_version}, for migration function {m.function.__name__}"
                )
            current_version = m.to_version

    def run_migrations(self, original_config: AppConfigDict) -> AppConfigDict:
        """
        Use the registered migrations to bring config up to latest version.

        Args:
            original_config: The original configuration. It is not mutated.

        Returns:
            The new configuration, lifted up to the latest version. If no migrations are registered,
            an unmodified copy of the original configuration is returned.

        Raises:
            ValueError: If the registered migrations do not form a contiguous chain starting at v3.0.0.
            AssertionError: If the migrated config does not end on the latest registered `to_version`
                (e.g. the input's schema version is not part of the migration chain).
        """
        # Order the migrations by source version and make sure they form an unbroken chain (ValueError otherwise).
        sorted_migrations = sorted(self._migrations, key=lambda x: x.from_version)
        self._check_for_discontinuities(sorted_migrations)

        # Do not mutate the incoming dict - we don't know who else may be using it
        migrated_config = deepcopy(original_config)

        # With nothing registered there is no target version to lift to; return the copy unchanged
        # instead of failing with an IndexError when looking up the latest version below.
        if not sorted_migrations:
            return migrated_config

        # v3.0.0 configs did not have "schema_version", but did have "InvokeAI"
        if "InvokeAI" in migrated_config:
            version = Version("3.0.0")
        else:
            # `str()` because YAML loads an unquoted version like `4.0` as a float, which `Version` rejects.
            version = Version(str(migrated_config["schema_version"]))

        # The chain is sorted and contiguous, so a single pass applies every migration from `version` onwards.
        for migration in sorted_migrations:
            if version == migration.from_version:
                migrated_config = migration.function(migrated_config)
                version = migration.to_version

        # We must end on the latest version
        latest_version = str(sorted_migrations[-1].to_version)
        assert (
            migrated_config["schema_version"] == latest_version
        ), f"Config could not be migrated to the latest schema version {latest_version}; it ended on {migrated_config['schema_version']}"
        return migrated_config
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
    """Load and migrate a config file to the latest version.

    Before migrating, the file is copied to `<name>.yaml.bak`. If the migration step raises, that backup is
    copied back over `config_path`. This function does not write the migrated config to disk itself.

    Args:
        config_path: Path to the config file. Must have a `.yaml` suffix.

    Returns:
        An instance of `InvokeAIAppConfig` with the loaded and migrated settings.

    Raises:
        AssertionError: If `config_path` does not end in `.yaml`, or the file does not contain a YAML mapping.
        RuntimeError: If migration fails, or if the migrated settings fail validation or do not have the
            expected schema version.
    """
    assert config_path.suffix == ".yaml"
    with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
        loaded_config_dict: AppConfigDict = yaml.safe_load(file)

    assert isinstance(loaded_config_dict, dict)

    backup_path = config_path.with_suffix(".yaml.bak")
    shutil.copy(config_path, backup_path)
    try:
        migrator = ConfigMigrator()
        migrator.register(config_migration_1)
        migrator.register(config_migration_2)
        migrated_config_dict = migrator.run_migrations(loaded_config_dict)
    except Exception as e:
        shutil.copy(backup_path, config_path)
        raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e

    # Attempt to load as a v4 config file
    try:
        config = InvokeAIAppConfig.model_validate(migrated_config_dict)
        # Explicit check rather than `assert`, so the validation still runs under `python -O`.
        if config.schema_version != CONFIG_SCHEMA_VERSION:
            raise ValueError(
                f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}"
            )
    except Exception as e:
        raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
    return config
# TODO(psyche): This must be in this file to avoid circular dependencies
@lru_cache(maxsize=1)
def get_config() -> InvokeAIAppConfig:
    """Get the global singleton app config.

    When first called, this function:
    - Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file.
    - If the CLI args have not been parsed (i.e. not running as the full application), returns that config as-is.
    - Otherwise, applies the `root` and `config_file` CLI args, if provided. It does not _parse_ the CLI args; that is done in the main entrypoint.
    - Writes an example config file (`*.example.yaml`) alongside the config file.
    - Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers).
    - If the config file exists, loads and migrates it, writes the migrated settings back to it, and merges them
      in without overriding settings from environment variables. Otherwise, writes out a default config file.

    On subsequent calls, the object is returned from the cache.
    """
    # This object includes environment variables, as parsed by pydantic-settings
    config = InvokeAIAppConfig()

    # This flag serves as a proxy for whether the config was retrieved in the context of the full application or not.
    # If it is False, we should just return a default config and not set the root, write any files, etc.
    if not InvokeAIArgs.did_parse:
        return config

    # Set CLI args
    args = InvokeAIArgs.args
    if root := getattr(args, "root", None):
        config._root = Path(root)
    if config_file := getattr(args, "config_file", None):
        config._config_file = Path(config_file)

    # Create the example config file, with some extra example values provided
    example_config = DefaultInvokeAIAppConfig()
    example_config.remote_api_tokens = [
        URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"),
        URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"),
    ]
    example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)

    # Copy all legacy configs - We know `__path__[0]` is correct here
    configs_src = Path(model_configs.__path__[0])  # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
    shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)

    if config.config_file_path.exists():
        config_from_file = load_and_migrate_config(config.config_file_path)
        # Persist the migrated settings so the file is on the latest schema version.
        config_from_file.write_file(config.config_file_path)
        # clobber=False: clobbering would overwrite any settings that were set via environment variables
        config.update_config(config_from_file, clobber=False)
    else:
        # We should never write env vars to the config file
        default_config = DefaultInvokeAIAppConfig()
        default_config.write_file(config.config_file_path, as_example=False)

    return config