add InvokeAIAppConfig schema migration system

This commit is contained in:
Lincoln Stein 2024-04-18 21:33:54 -04:00
parent a35386f24c
commit 6ad1948a44
3 changed files with 212 additions and 87 deletions

View File

@ -20,6 +20,8 @@ import invokeai.configs as model_configs
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
from .config_migrate import ConfigMigrator
INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db")
LEGACY_INIT_FILE = Path("invokeai.init")
@ -348,75 +350,6 @@ class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
return (init_settings,)
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate a v3 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v3 config file.
Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
"""
parsed_config_dict: dict[str, Any] = {}
for _category_name, category_dict in config_dict["InvokeAI"].items():
for k, v in category_dict.items():
# `outdir` was renamed to `outputs_dir` in v4
if k == "outdir":
parsed_config_dict["outputs_dir"] = v
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
# If if the incoming config has the default value, skip
continue
elif Path(v).name == "stable-diffusion":
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
else:
# Else we do not attempt to migrate this setting
parsed_config_dict["legacy_conf_dir"] = v
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
# When migrating the config file, we should not include currently-set environment variables.
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
"""Migrate v4.0.0 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v4.0.0 config file.
Returns:
An instance of `InvokeAIAppConfig` with the migrated settings.
"""
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
else:
parsed_config_dict[k] = v
if k == "schema_version":
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
return config
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
@ -432,29 +365,20 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
assert isinstance(loaded_config_dict, dict)
if "InvokeAI" in loaded_config_dict:
# This is a v3 config file, attempt to migrate it
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
migrated_config = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
migrated_config.write_file(config_path)
return migrated_config
if loaded_config_dict["schema_version"] == "4.0.0":
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
loaded_config_dict.write_file(config_path)
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
migrated_config_dict = ConfigMigrator.migrate(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e
# Attempt to load as a v4 config file
try:
# Meta is not included in the model fields, so we need to validate it separately
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
@ -504,6 +428,7 @@ def get_config() -> InvokeAIAppConfig:
if config.config_file_path.exists():
config_from_file = load_and_migrate_config(config.config_file_path)
config_from_file.write_file(config.config_file_path)
# Clobbering here will overwrite any settings that were set via environment variables
config.update_config(config_from_file, clobber=False)
else:
@ -512,3 +437,73 @@ def get_config() -> InvokeAIAppConfig:
default_config.write_file(config.config_file_path, as_example=False)
return config
####################################################
# VERSION MIGRATIONS
####################################################
@ConfigMigrator.register(from_version="0.0.0", to_version="4.0.0")
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate a v3 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v3 config file.
Returns:
A dictionary of settings from a 4.0.0 config file.
"""
parsed_config_dict: dict[str, Any] = {}
for _category_name, category_dict in config_dict["InvokeAI"].items():
for k, v in category_dict.items():
# `outdir` was renamed to `outputs_dir` in v4
if k == "outdir":
parsed_config_dict["outputs_dir"] = v
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
# If if the incoming config has the default value, skip
continue
elif Path(v).name == "stable-diffusion":
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
else:
# Else we do not attempt to migrate this setting
parsed_config_dict["legacy_conf_dir"] = v
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
return parsed_config_dict
@ConfigMigrator.register(from_version="4.0.0", to_version="4.0.1")
def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.0 config dictionary to v4.0.1.
Args:
config_dict: A dictionary of settings from a v4.0.0 config file.
Returns:
A dictionary of settings from a v4.0.1 config file
"""
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
else:
parsed_config_dict[k] = v
return parsed_config_dict

View File

@ -0,0 +1,129 @@
# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team
"""
Utility class for migrating among versions of the InvokeAI app config schema.
"""
from abc import ABC, abstractmethod
from typing import TYPE_CHECKING, Any, Callable, List, TypeVar
from pydantic import BaseModel, ConfigDict, field_validator
from version_parser import Version
if TYPE_CHECKING:
pass
AppConfigDict = TypeVar("AppConfigDict", bound=dict[str, Any])
class AppVersion(Version):
"""Stringlike object that sorts like a version."""
def __hash__(self) -> int: # noqa D105
return hash(str(self))
def __repr__(self) -> str: # noqa D105
return f"AppVersion('{str(self)}')"
class ConfigMigratorBase(ABC):
"""This class allows migrators to register their input and output versions."""
@classmethod
@abstractmethod
def register(
cls, from_version: AppVersion, to_version: AppVersion
) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]:
"""Define a decorator which registers the migration between two versions."""
@classmethod
@abstractmethod
def migrate(cls, config: AppConfigDict) -> AppConfigDict:
"""
Use the registered migration steps to bring config up to latest version.
:param config: The original configuration.
:return: The new configuration, lifted up to the latest version.
As a side effect, the new configuration will be written to disk.
"""
class MigrationEntry(BaseModel):
"""Defines an individual migration."""
model_config = ConfigDict(arbitrary_types_allowed=True)
from_version: AppVersion
to_version: AppVersion
function: Callable[[AppConfigDict], AppConfigDict]
@field_validator("from_version", "to_version", mode="before")
@classmethod
def _string_to_version(cls, v: str | AppVersion) -> AppVersion: # noqa D102
if isinstance(v, str):
return AppVersion(v)
else:
return v
class ConfigMigrator(ConfigMigratorBase):
"""This class allows migrators to register their input and output versions."""
_migrations: List[MigrationEntry] = []
@classmethod
def register(
cls,
from_version: AppVersion | str,
to_version: AppVersion | str,
) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]:
"""Define a decorator which registers the migration between two versions."""
def decorator(function: Callable[[AppConfigDict], AppConfigDict]) -> Callable[[AppConfigDict], AppConfigDict]:
if from_version in cls._migrations:
raise ValueError(
f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered."
)
cls._migrations.append(MigrationEntry(from_version=from_version, to_version=to_version, function=function))
return function
return decorator
@staticmethod
def _check_for_overlaps(migrations: List[MigrationEntry]) -> None:
current_version = AppVersion("0.0.0")
for m in migrations:
if current_version > m.from_version:
raise ValueError(f"Version range overlap detected while processing function {m.function.__name__}")
@classmethod
def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict:
"""
Use the registered migration steps to bring config up to latest version.
:param config: The original configuration.
:return: The new configuration, lifted up to the latest version.
As a side effect, the new configuration will be written to disk.
If an inconsistency in the registered migration steps' `from_version`
and `to_version` parameters are identified, this will raise a
ValueError exception.
"""
# Sort migrations by version number and raise a ValueError if
# any version range overlaps are detected. Discontinuities are ok
sorted_migrations = sorted(cls._migrations, key=lambda x: x.from_version)
cls._check_for_overlaps(sorted_migrations)
if "InvokeAI" in config_dict:
version = AppVersion("3.0.0")
else:
version = AppVersion(config_dict["schema_version"])
for migration in sorted_migrations:
if version >= migration.from_version and version < migration.to_version:
config_dict = migration.function(config_dict)
version = migration.to_version
config_dict["schema_version"] = str(version)
return config_dict

View File

@ -90,6 +90,7 @@ dependencies = [
"semver~=3.0.1",
"send2trash",
"test-tube~=0.7.5",
"version-parser",
"windows-curses; sys_platform=='win32'",
]