mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add InvokeAIAppConfig schema migration system
This commit is contained in:
parent
a35386f24c
commit
6ad1948a44
@ -20,6 +20,8 @@ import invokeai.configs as model_configs
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
from .config_migrate import ConfigMigrator
|
||||
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
LEGACY_INIT_FILE = Path("invokeai.init")
|
||||
@ -348,75 +350,6 @@ class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
|
||||
return (init_settings,)
|
||||
|
||||
|
||||
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
||||
"""Migrate a v3 config dictionary to a current config object.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v3 config file.
|
||||
|
||||
Returns:
|
||||
An instance of `InvokeAIAppConfig` with the migrated settings.
|
||||
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for _category_name, category_dict in config_dict["InvokeAI"].items():
|
||||
for k, v in category_dict.items():
|
||||
# `outdir` was renamed to `outputs_dir` in v4
|
||||
if k == "outdir":
|
||||
parsed_config_dict["outputs_dir"] = v
|
||||
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
|
||||
if k == "max_cache_size" and "ram" not in category_dict:
|
||||
parsed_config_dict["ram"] = v
|
||||
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
||||
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
||||
parsed_config_dict["vram"] = v
|
||||
# autocast was removed in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
if k == "conf_path":
|
||||
parsed_config_dict["legacy_models_yaml_path"] = v
|
||||
if k == "legacy_conf_dir":
|
||||
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
|
||||
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
|
||||
# If if the incoming config has the default value, skip
|
||||
continue
|
||||
elif Path(v).name == "stable-diffusion":
|
||||
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
||||
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
|
||||
else:
|
||||
# Else we do not attempt to migrate this setting
|
||||
parsed_config_dict["legacy_conf_dir"] = v
|
||||
elif k in InvokeAIAppConfig.model_fields:
|
||||
# skip unknown fields
|
||||
parsed_config_dict[k] = v
|
||||
# When migrating the config file, we should not include currently-set environment variables.
|
||||
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
||||
"""Migrate v4.0.0 config dictionary to a current config object.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v4.0.0 config file.
|
||||
|
||||
Returns:
|
||||
An instance of `InvokeAIAppConfig` with the migrated settings.
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for k, v in config_dict.items():
|
||||
# autocast was removed from precision in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
else:
|
||||
parsed_config_dict[k] = v
|
||||
if k == "schema_version":
|
||||
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
|
||||
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
|
||||
return config
|
||||
|
||||
|
||||
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
"""Load and migrate a config file to the latest version.
|
||||
|
||||
@ -432,29 +365,20 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
|
||||
assert isinstance(loaded_config_dict, dict)
|
||||
|
||||
if "InvokeAI" in loaded_config_dict:
|
||||
# This is a v3 config file, attempt to migrate it
|
||||
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
|
||||
try:
|
||||
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
|
||||
migrated_config = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
|
||||
except Exception as e:
|
||||
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
|
||||
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
|
||||
migrated_config.write_file(config_path)
|
||||
return migrated_config
|
||||
|
||||
if loaded_config_dict["schema_version"] == "4.0.0":
|
||||
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
|
||||
loaded_config_dict.write_file(config_path)
|
||||
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
|
||||
try:
|
||||
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
|
||||
migrated_config_dict = ConfigMigrator.migrate(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
|
||||
except Exception as e:
|
||||
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
|
||||
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e
|
||||
|
||||
# Attempt to load as a v4 config file
|
||||
try:
|
||||
# Meta is not included in the model fields, so we need to validate it separately
|
||||
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
|
||||
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
|
||||
assert (
|
||||
config.schema_version == CONFIG_SCHEMA_VERSION
|
||||
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
|
||||
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}"
|
||||
return config
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
||||
@ -504,6 +428,7 @@ def get_config() -> InvokeAIAppConfig:
|
||||
|
||||
if config.config_file_path.exists():
|
||||
config_from_file = load_and_migrate_config(config.config_file_path)
|
||||
config_from_file.write_file(config.config_file_path)
|
||||
# Clobbering here will overwrite any settings that were set via environment variables
|
||||
config.update_config(config_from_file, clobber=False)
|
||||
else:
|
||||
@ -512,3 +437,73 @@ def get_config() -> InvokeAIAppConfig:
|
||||
default_config.write_file(config.config_file_path, as_example=False)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
####################################################
|
||||
# VERSION MIGRATIONS
|
||||
####################################################
|
||||
|
||||
|
||||
@ConfigMigrator.register(from_version="0.0.0", to_version="4.0.0")
|
||||
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Migrate a v3 config dictionary to a current config object.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v3 config file.
|
||||
|
||||
Returns:
|
||||
A dictionary of settings from a 4.0.0 config file.
|
||||
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for _category_name, category_dict in config_dict["InvokeAI"].items():
|
||||
for k, v in category_dict.items():
|
||||
# `outdir` was renamed to `outputs_dir` in v4
|
||||
if k == "outdir":
|
||||
parsed_config_dict["outputs_dir"] = v
|
||||
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
|
||||
if k == "max_cache_size" and "ram" not in category_dict:
|
||||
parsed_config_dict["ram"] = v
|
||||
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
||||
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
||||
parsed_config_dict["vram"] = v
|
||||
# autocast was removed in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
if k == "conf_path":
|
||||
parsed_config_dict["legacy_models_yaml_path"] = v
|
||||
if k == "legacy_conf_dir":
|
||||
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
|
||||
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
|
||||
# If if the incoming config has the default value, skip
|
||||
continue
|
||||
elif Path(v).name == "stable-diffusion":
|
||||
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
||||
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
|
||||
else:
|
||||
# Else we do not attempt to migrate this setting
|
||||
parsed_config_dict["legacy_conf_dir"] = v
|
||||
elif k in InvokeAIAppConfig.model_fields:
|
||||
# skip unknown fields
|
||||
parsed_config_dict[k] = v
|
||||
return parsed_config_dict
|
||||
|
||||
|
||||
@ConfigMigrator.register(from_version="4.0.0", to_version="4.0.1")
|
||||
def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Migrate v4.0.0 config dictionary to v4.0.1.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v4.0.0 config file.
|
||||
|
||||
Returns:
|
||||
A dictionary of settings from a v4.0.1 config file
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for k, v in config_dict.items():
|
||||
# autocast was removed from precision in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
else:
|
||||
parsed_config_dict[k] = v
|
||||
return parsed_config_dict
|
||||
|
129
invokeai/app/services/config/config_migrate.py
Normal file
129
invokeai/app/services/config/config_migrate.py
Normal file
@ -0,0 +1,129 @@
|
||||
# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
Utility class for migrating among versions of the InvokeAI app config schema.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import TYPE_CHECKING, Any, Callable, List, TypeVar
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, field_validator
|
||||
from version_parser import Version
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
AppConfigDict = TypeVar("AppConfigDict", bound=dict[str, Any])
|
||||
|
||||
|
||||
class AppVersion(Version):
|
||||
"""Stringlike object that sorts like a version."""
|
||||
|
||||
def __hash__(self) -> int: # noqa D105
|
||||
return hash(str(self))
|
||||
|
||||
def __repr__(self) -> str: # noqa D105
|
||||
return f"AppVersion('{str(self)}')"
|
||||
|
||||
|
||||
class ConfigMigratorBase(ABC):
|
||||
"""This class allows migrators to register their input and output versions."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def register(
|
||||
cls, from_version: AppVersion, to_version: AppVersion
|
||||
) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]:
|
||||
"""Define a decorator which registers the migration between two versions."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def migrate(cls, config: AppConfigDict) -> AppConfigDict:
|
||||
"""
|
||||
Use the registered migration steps to bring config up to latest version.
|
||||
|
||||
:param config: The original configuration.
|
||||
:return: The new configuration, lifted up to the latest version.
|
||||
|
||||
As a side effect, the new configuration will be written to disk.
|
||||
"""
|
||||
|
||||
|
||||
class MigrationEntry(BaseModel):
|
||||
"""Defines an individual migration."""
|
||||
|
||||
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||
|
||||
from_version: AppVersion
|
||||
to_version: AppVersion
|
||||
function: Callable[[AppConfigDict], AppConfigDict]
|
||||
|
||||
@field_validator("from_version", "to_version", mode="before")
|
||||
@classmethod
|
||||
def _string_to_version(cls, v: str | AppVersion) -> AppVersion: # noqa D102
|
||||
if isinstance(v, str):
|
||||
return AppVersion(v)
|
||||
else:
|
||||
return v
|
||||
|
||||
|
||||
class ConfigMigrator(ConfigMigratorBase):
|
||||
"""This class allows migrators to register their input and output versions."""
|
||||
|
||||
_migrations: List[MigrationEntry] = []
|
||||
|
||||
@classmethod
|
||||
def register(
|
||||
cls,
|
||||
from_version: AppVersion | str,
|
||||
to_version: AppVersion | str,
|
||||
) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]:
|
||||
"""Define a decorator which registers the migration between two versions."""
|
||||
|
||||
def decorator(function: Callable[[AppConfigDict], AppConfigDict]) -> Callable[[AppConfigDict], AppConfigDict]:
|
||||
if from_version in cls._migrations:
|
||||
raise ValueError(
|
||||
f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered."
|
||||
)
|
||||
cls._migrations.append(MigrationEntry(from_version=from_version, to_version=to_version, function=function))
|
||||
return function
|
||||
|
||||
return decorator
|
||||
|
||||
@staticmethod
|
||||
def _check_for_overlaps(migrations: List[MigrationEntry]) -> None:
|
||||
current_version = AppVersion("0.0.0")
|
||||
for m in migrations:
|
||||
if current_version > m.from_version:
|
||||
raise ValueError(f"Version range overlap detected while processing function {m.function.__name__}")
|
||||
|
||||
@classmethod
|
||||
def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict:
|
||||
"""
|
||||
Use the registered migration steps to bring config up to latest version.
|
||||
|
||||
:param config: The original configuration.
|
||||
:return: The new configuration, lifted up to the latest version.
|
||||
|
||||
As a side effect, the new configuration will be written to disk.
|
||||
If an inconsistency in the registered migration steps' `from_version`
|
||||
and `to_version` parameters are identified, this will raise a
|
||||
ValueError exception.
|
||||
"""
|
||||
# Sort migrations by version number and raise a ValueError if
|
||||
# any version range overlaps are detected. Discontinuities are ok
|
||||
sorted_migrations = sorted(cls._migrations, key=lambda x: x.from_version)
|
||||
cls._check_for_overlaps(sorted_migrations)
|
||||
|
||||
if "InvokeAI" in config_dict:
|
||||
version = AppVersion("3.0.0")
|
||||
else:
|
||||
version = AppVersion(config_dict["schema_version"])
|
||||
|
||||
for migration in sorted_migrations:
|
||||
if version >= migration.from_version and version < migration.to_version:
|
||||
config_dict = migration.function(config_dict)
|
||||
version = migration.to_version
|
||||
|
||||
config_dict["schema_version"] = str(version)
|
||||
return config_dict
|
@ -90,6 +90,7 @@ dependencies = [
|
||||
"semver~=3.0.1",
|
||||
"send2trash",
|
||||
"test-tube~=0.7.5",
|
||||
"version-parser",
|
||||
"windows-curses; sys_platform=='win32'",
|
||||
]
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user