mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
175 lines
7.5 KiB
Python
175 lines
7.5 KiB
Python
# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team
|
|
|
|
"""
|
|
Utility class for migrating among versions of the InvokeAI app config schema.
|
|
"""
|
|
|
|
import locale
|
|
import shutil
|
|
from copy import deepcopy
|
|
from functools import lru_cache
|
|
from pathlib import Path
|
|
from typing import Iterable
|
|
|
|
import yaml
|
|
from packaging.version import Version
|
|
|
|
import invokeai.configs as model_configs
|
|
from invokeai.app.services.config.config_common import AppConfigDict, ConfigMigration
|
|
from invokeai.app.services.config.migrations import config_migration_1, config_migration_2
|
|
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
|
|
|
from .config_default import CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, URLRegexTokenPair
|
|
|
|
|
|
class ConfigMigrator:
    """Registry of config schema migrations that can be applied in sequence.

    Each registered `ConfigMigration` supplies the schema version it upgrades from (`from_version`), the version it
    produces (`to_version`) and the function that performs the upgrade (`function`). `run_migrations()` applies the
    registered migrations in ascending `from_version` order.
    """

    def __init__(self) -> None:
        self._migrations: set[ConfigMigration] = set()

    def register(self, migration: ConfigMigration) -> None:
        """Register a migration.

        Args:
            migration: The migration to register.

        Raises:
            ValueError: If a migration with the same `from_version` or the same `to_version` is already registered.
        """
        migration_from_already_registered = any(m.from_version == migration.from_version for m in self._migrations)
        migration_to_already_registered = any(m.to_version == migration.to_version for m in self._migrations)
        if migration_from_already_registered or migration_to_already_registered:
            raise ValueError(
                f"A migration from {migration.from_version} or to {migration.to_version} has already been registered."
            )
        self._migrations.add(migration)

    @staticmethod
    def _check_for_discontinuities(migrations: Iterable[ConfigMigration]) -> None:
        """Verify that the migrations form an unbroken chain starting at version 3.0.0.

        Args:
            migrations: The migrations to check, in any order.

        Raises:
            ValueError: If a migration's `from_version` differs from the previous migration's `to_version`
                (or from 3.0.0 for the first migration).
        """
        current_version = Version("3.0.0")
        sorted_migrations = sorted(migrations, key=lambda x: x.from_version)
        for m in sorted_migrations:
            if current_version != m.from_version:
                raise ValueError(
                    f"Migration functions are not continuous. Expected from_version={current_version} but got from_version={m.from_version}, for migration function {m.function.__name__}"
                )
            current_version = m.to_version

    def run_migrations(self, original_config: AppConfigDict) -> AppConfigDict:
        """
        Use the registered migrations to bring config up to latest version.

        Args:
            original_config: The original configuration. It is never mutated; migrations run on a deep copy.

        Returns:
            The new configuration, lifted up to the latest version. If no migrations are registered, an unmodified
            deep copy of `original_config` is returned.

        Raises:
            ValueError: If the registered migrations are not continuous, or if the config does not end up on the
                `to_version` of the last registered migration.
            KeyError: If the config has neither an "InvokeAI" key nor a "schema_version" key.
        """
        # Sort migrations by version number and raise a ValueError if the chain of versions has any gaps.
        sorted_migrations = sorted(self._migrations, key=lambda x: x.from_version)
        if not sorted_migrations:
            # Nothing is registered, so there is nothing to apply and no "latest" version to enforce.
            return deepcopy(original_config)
        self._check_for_discontinuities(sorted_migrations)

        # Do not mutate the incoming dict - we don't know who else may be using it
        migrated_config = deepcopy(original_config)

        # v3.0.0 configs did not have "schema_version", but did have "InvokeAI"
        if "InvokeAI" in migrated_config:
            version = Version("3.0.0")
        else:
            version = Version(migrated_config["schema_version"])

        # Migrations are sorted, so once the config's version is reached, every later migration applies in turn.
        for migration in sorted_migrations:
            if version == migration.from_version:
                migrated_config = migration.function(migrated_config)
                version = migration.to_version

        # We must end on the latest version. This is an explicit check rather than an `assert` so that it still
        # runs when Python is started with -O.
        latest_version = str(sorted_migrations[-1].to_version)
        final_version = migrated_config.get("schema_version")
        if final_version != latest_version:
            raise ValueError(
                f"Config migration did not end on the latest schema version. Expected {latest_version} but got {final_version!r}"
            )
        return migrated_config
|
|
|
|
|
|
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
    """Load and migrate a config file to the latest version.

    Before migrating, the file is copied alongside itself with a `.yaml.bak` suffix. If the migration fails, that
    backup is copied back over `config_path` before the error is raised. This function does not write the migrated
    settings to disk.

    Args:
        config_path: Path to the config file. Must have a `.yaml` suffix.

    Returns:
        An instance of `InvokeAIAppConfig` with the loaded and migrated settings.

    Raises:
        ValueError: If `config_path` does not have a `.yaml` suffix, or if the top-level YAML value in the file is
            not a mapping (e.g. the file is empty).
        RuntimeError: If the migrations fail, or if the migrated settings fail validation or do not carry the
            expected schema version.
    """
    # Explicit checks rather than `assert`, so the validation still runs when Python is started with -O.
    if config_path.suffix != ".yaml":
        raise ValueError(f"Config file must have a .yaml suffix, got {config_path}")

    with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
        loaded_config_dict: AppConfigDict = yaml.safe_load(file)

    if not isinstance(loaded_config_dict, dict):
        raise ValueError(
            f"Config file {config_path} must contain a YAML mapping, got {type(loaded_config_dict).__name__}"
        )

    backup_path = config_path.with_suffix(".yaml.bak")
    shutil.copy(config_path, backup_path)
    try:
        migrator = ConfigMigrator()
        migrator.register(config_migration_1)
        migrator.register(config_migration_2)
        migrated_config_dict = migrator.run_migrations(loaded_config_dict)
    except Exception as e:
        # Put the backup back in place before reporting the failure
        shutil.copy(backup_path, config_path)
        raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e

    # Attempt to load as a v4 config file
    try:
        config = InvokeAIAppConfig.model_validate(migrated_config_dict)
        if config.schema_version != CONFIG_SCHEMA_VERSION:
            raise ValueError(
                f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}"
            )
        return config
    except Exception as e:
        raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
|
|
|
|
|
# TODO(psyche): This must be in this file to avoid circular dependencies
|
|
@lru_cache(maxsize=1)
def get_config() -> InvokeAIAppConfig:
    """Return the global singleton app config, building it on the first call.

    The first call:
    - Instantiates `InvokeAIAppConfig`. `pydantic-settings` merges in environment variables, but not the config file.
    - Reads the CLI args stored on `InvokeAIArgs`. They are not parsed here; that happens in the main entrypoint.
      If the args were never parsed, the freshly created config is returned immediately and no further steps run.
    - Applies the `root` and `config_file` CLI overrides, if provided.
    - Writes an example config file next to the real one, using the `.example.yaml` suffix.
    - Copies all packaged legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers).
    - If the config file exists, loads and migrates it, writes the result back to the same path and merges it into
      the config. Otherwise, writes out a default config file.

    Subsequent calls return the cached object.
    """
    # Includes any environment variables, as parsed by pydantic-settings
    app_config = InvokeAIAppConfig()

    cli_args = InvokeAIArgs.args

    # `did_parse` stands in for "are we running inside the full application?". When it is False, hand back the
    # plain config without touching the root dir or the filesystem.
    if not InvokeAIArgs.did_parse:
        return app_config

    # Apply CLI overrides
    root_override = getattr(cli_args, "root", None)
    if root_override:
        app_config._root = Path(root_override)
    config_file_override = getattr(cli_args, "config_file", None)
    if config_file_override:
        app_config._config_file = Path(config_file_override)

    # Write the example config file, seeded with a couple of illustrative token entries
    example = DefaultInvokeAIAppConfig()
    example.remote_api_tokens = [
        URLRegexTokenPair(url_regex=url_regex, token=token)
        for url_regex, token in (
            ("cool-models.com", "my_secret_token"),
            ("nifty-models.com", "some_other_token"),
        )
    ]
    example.write_file(app_config.config_file_path.with_suffix(".example.yaml"), as_example=True)

    # Copy all legacy configs - We know `__path__[0]` is correct here
    packaged_configs_dir = Path(model_configs.__path__[0])  # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
    shutil.copytree(packaged_configs_dir, app_config.legacy_conf_path, dirs_exist_ok=True)

    if not app_config.config_file_path.exists():
        # No config file yet: write one from pristine defaults, so env vars never end up in the file
        DefaultInvokeAIAppConfig().write_file(app_config.config_file_path, as_example=False)
        return app_config

    file_config = load_and_migrate_config(app_config.config_file_path)
    file_config.write_file(app_config.config_file_path)
    # Clobbering would overwrite any settings that were set via environment variables, so it is disabled
    app_config.update_config(file_config, clobber=False)

    return app_config
|