mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
add InvokeAIAppConfig schema migration system
This commit is contained in:
parent
a35386f24c
commit
6ad1948a44
@ -20,6 +20,8 @@ import invokeai.configs as model_configs
|
|||||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||||
|
|
||||||
|
from .config_migrate import ConfigMigrator
|
||||||
|
|
||||||
INIT_FILE = Path("invokeai.yaml")
|
INIT_FILE = Path("invokeai.yaml")
|
||||||
DB_FILE = Path("invokeai.db")
|
DB_FILE = Path("invokeai.db")
|
||||||
LEGACY_INIT_FILE = Path("invokeai.init")
|
LEGACY_INIT_FILE = Path("invokeai.init")
|
||||||
@ -348,75 +350,6 @@ class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
|
|||||||
return (init_settings,)
|
return (init_settings,)
|
||||||
|
|
||||||
|
|
||||||
def migrate_v3_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
|
||||||
"""Migrate a v3 config dictionary to a current config object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dict: A dictionary of settings from a v3 config file.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An instance of `InvokeAIAppConfig` with the migrated settings.
|
|
||||||
|
|
||||||
"""
|
|
||||||
parsed_config_dict: dict[str, Any] = {}
|
|
||||||
for _category_name, category_dict in config_dict["InvokeAI"].items():
|
|
||||||
for k, v in category_dict.items():
|
|
||||||
# `outdir` was renamed to `outputs_dir` in v4
|
|
||||||
if k == "outdir":
|
|
||||||
parsed_config_dict["outputs_dir"] = v
|
|
||||||
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
|
|
||||||
if k == "max_cache_size" and "ram" not in category_dict:
|
|
||||||
parsed_config_dict["ram"] = v
|
|
||||||
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
|
||||||
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
|
||||||
parsed_config_dict["vram"] = v
|
|
||||||
# autocast was removed in v4.0.1
|
|
||||||
if k == "precision" and v == "autocast":
|
|
||||||
parsed_config_dict["precision"] = "auto"
|
|
||||||
if k == "conf_path":
|
|
||||||
parsed_config_dict["legacy_models_yaml_path"] = v
|
|
||||||
if k == "legacy_conf_dir":
|
|
||||||
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
|
|
||||||
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
|
|
||||||
# If if the incoming config has the default value, skip
|
|
||||||
continue
|
|
||||||
elif Path(v).name == "stable-diffusion":
|
|
||||||
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
|
||||||
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
|
|
||||||
else:
|
|
||||||
# Else we do not attempt to migrate this setting
|
|
||||||
parsed_config_dict["legacy_conf_dir"] = v
|
|
||||||
elif k in InvokeAIAppConfig.model_fields:
|
|
||||||
# skip unknown fields
|
|
||||||
parsed_config_dict[k] = v
|
|
||||||
# When migrating the config file, we should not include currently-set environment variables.
|
|
||||||
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def migrate_v4_0_0_config_dict(config_dict: dict[str, Any]) -> InvokeAIAppConfig:
|
|
||||||
"""Migrate v4.0.0 config dictionary to a current config object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dict: A dictionary of settings from a v4.0.0 config file.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An instance of `InvokeAIAppConfig` with the migrated settings.
|
|
||||||
"""
|
|
||||||
parsed_config_dict: dict[str, Any] = {}
|
|
||||||
for k, v in config_dict.items():
|
|
||||||
# autocast was removed from precision in v4.0.1
|
|
||||||
if k == "precision" and v == "autocast":
|
|
||||||
parsed_config_dict["precision"] = "auto"
|
|
||||||
else:
|
|
||||||
parsed_config_dict[k] = v
|
|
||||||
if k == "schema_version":
|
|
||||||
parsed_config_dict[k] = CONFIG_SCHEMA_VERSION
|
|
||||||
config = DefaultInvokeAIAppConfig.model_validate(parsed_config_dict)
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||||
"""Load and migrate a config file to the latest version.
|
"""Load and migrate a config file to the latest version.
|
||||||
|
|
||||||
@ -432,29 +365,20 @@ def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
|||||||
|
|
||||||
assert isinstance(loaded_config_dict, dict)
|
assert isinstance(loaded_config_dict, dict)
|
||||||
|
|
||||||
if "InvokeAI" in loaded_config_dict:
|
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
|
||||||
# This is a v3 config file, attempt to migrate it
|
try:
|
||||||
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
|
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
|
||||||
try:
|
migrated_config_dict = ConfigMigrator.migrate(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
|
||||||
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
|
except Exception as e:
|
||||||
migrated_config = migrate_v3_config_dict(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
|
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
|
||||||
except Exception as e:
|
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e
|
||||||
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
|
|
||||||
raise RuntimeError(f"Failed to load and migrate v3 config file {config_path}: {e}") from e
|
|
||||||
migrated_config.write_file(config_path)
|
|
||||||
return migrated_config
|
|
||||||
|
|
||||||
if loaded_config_dict["schema_version"] == "4.0.0":
|
|
||||||
loaded_config_dict = migrate_v4_0_0_config_dict(loaded_config_dict)
|
|
||||||
loaded_config_dict.write_file(config_path)
|
|
||||||
|
|
||||||
# Attempt to load as a v4 config file
|
# Attempt to load as a v4 config file
|
||||||
try:
|
try:
|
||||||
# Meta is not included in the model fields, so we need to validate it separately
|
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
|
||||||
config = InvokeAIAppConfig.model_validate(loaded_config_dict)
|
|
||||||
assert (
|
assert (
|
||||||
config.schema_version == CONFIG_SCHEMA_VERSION
|
config.schema_version == CONFIG_SCHEMA_VERSION
|
||||||
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION}: {config.schema_version}"
|
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}"
|
||||||
return config
|
return config
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
||||||
@ -504,6 +428,7 @@ def get_config() -> InvokeAIAppConfig:
|
|||||||
|
|
||||||
if config.config_file_path.exists():
|
if config.config_file_path.exists():
|
||||||
config_from_file = load_and_migrate_config(config.config_file_path)
|
config_from_file = load_and_migrate_config(config.config_file_path)
|
||||||
|
config_from_file.write_file(config.config_file_path)
|
||||||
# Clobbering here will overwrite any settings that were set via environment variables
|
# Clobbering here will overwrite any settings that were set via environment variables
|
||||||
config.update_config(config_from_file, clobber=False)
|
config.update_config(config_from_file, clobber=False)
|
||||||
else:
|
else:
|
||||||
@ -512,3 +437,73 @@ def get_config() -> InvokeAIAppConfig:
|
|||||||
default_config.write_file(config.config_file_path, as_example=False)
|
default_config.write_file(config.config_file_path, as_example=False)
|
||||||
|
|
||||||
return config
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
####################################################
|
||||||
|
# VERSION MIGRATIONS
|
||||||
|
####################################################
|
||||||
|
|
||||||
|
|
||||||
|
@ConfigMigrator.register(from_version="0.0.0", to_version="4.0.0")
|
||||||
|
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Migrate a v3 config dictionary to a current config object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_dict: A dictionary of settings from a v3 config file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of settings from a 4.0.0 config file.
|
||||||
|
|
||||||
|
"""
|
||||||
|
parsed_config_dict: dict[str, Any] = {}
|
||||||
|
for _category_name, category_dict in config_dict["InvokeAI"].items():
|
||||||
|
for k, v in category_dict.items():
|
||||||
|
# `outdir` was renamed to `outputs_dir` in v4
|
||||||
|
if k == "outdir":
|
||||||
|
parsed_config_dict["outputs_dir"] = v
|
||||||
|
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
|
||||||
|
if k == "max_cache_size" and "ram" not in category_dict:
|
||||||
|
parsed_config_dict["ram"] = v
|
||||||
|
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
||||||
|
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
||||||
|
parsed_config_dict["vram"] = v
|
||||||
|
# autocast was removed in v4.0.1
|
||||||
|
if k == "precision" and v == "autocast":
|
||||||
|
parsed_config_dict["precision"] = "auto"
|
||||||
|
if k == "conf_path":
|
||||||
|
parsed_config_dict["legacy_models_yaml_path"] = v
|
||||||
|
if k == "legacy_conf_dir":
|
||||||
|
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
|
||||||
|
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
|
||||||
|
# If if the incoming config has the default value, skip
|
||||||
|
continue
|
||||||
|
elif Path(v).name == "stable-diffusion":
|
||||||
|
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
||||||
|
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
|
||||||
|
else:
|
||||||
|
# Else we do not attempt to migrate this setting
|
||||||
|
parsed_config_dict["legacy_conf_dir"] = v
|
||||||
|
elif k in InvokeAIAppConfig.model_fields:
|
||||||
|
# skip unknown fields
|
||||||
|
parsed_config_dict[k] = v
|
||||||
|
return parsed_config_dict
|
||||||
|
|
||||||
|
|
||||||
|
@ConfigMigrator.register(from_version="4.0.0", to_version="4.0.1")
|
||||||
|
def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Migrate v4.0.0 config dictionary to v4.0.1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_dict: A dictionary of settings from a v4.0.0 config file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of settings from a v4.0.1 config file
|
||||||
|
"""
|
||||||
|
parsed_config_dict: dict[str, Any] = {}
|
||||||
|
for k, v in config_dict.items():
|
||||||
|
# autocast was removed from precision in v4.0.1
|
||||||
|
if k == "precision" and v == "autocast":
|
||||||
|
parsed_config_dict["precision"] = "auto"
|
||||||
|
else:
|
||||||
|
parsed_config_dict[k] = v
|
||||||
|
return parsed_config_dict
|
||||||
|
129
invokeai/app/services/config/config_migrate.py
Normal file
129
invokeai/app/services/config/config_migrate.py
Normal file
@ -0,0 +1,129 @@
|
|||||||
|
# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team
|
||||||
|
|
||||||
|
"""
|
||||||
|
Utility class for migrating among versions of the InvokeAI app config schema.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from typing import TYPE_CHECKING, Any, Callable, List, TypeVar
|
||||||
|
|
||||||
|
from pydantic import BaseModel, ConfigDict, field_validator
|
||||||
|
from version_parser import Version
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
pass
|
||||||
|
|
||||||
|
AppConfigDict = TypeVar("AppConfigDict", bound=dict[str, Any])
|
||||||
|
|
||||||
|
|
||||||
|
class AppVersion(Version):
|
||||||
|
"""Stringlike object that sorts like a version."""
|
||||||
|
|
||||||
|
def __hash__(self) -> int: # noqa D105
|
||||||
|
return hash(str(self))
|
||||||
|
|
||||||
|
def __repr__(self) -> str: # noqa D105
|
||||||
|
return f"AppVersion('{str(self)}')"
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigMigratorBase(ABC):
|
||||||
|
"""This class allows migrators to register their input and output versions."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def register(
|
||||||
|
cls, from_version: AppVersion, to_version: AppVersion
|
||||||
|
) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]:
|
||||||
|
"""Define a decorator which registers the migration between two versions."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def migrate(cls, config: AppConfigDict) -> AppConfigDict:
|
||||||
|
"""
|
||||||
|
Use the registered migration steps to bring config up to latest version.
|
||||||
|
|
||||||
|
:param config: The original configuration.
|
||||||
|
:return: The new configuration, lifted up to the latest version.
|
||||||
|
|
||||||
|
As a side effect, the new configuration will be written to disk.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class MigrationEntry(BaseModel):
|
||||||
|
"""Defines an individual migration."""
|
||||||
|
|
||||||
|
model_config = ConfigDict(arbitrary_types_allowed=True)
|
||||||
|
|
||||||
|
from_version: AppVersion
|
||||||
|
to_version: AppVersion
|
||||||
|
function: Callable[[AppConfigDict], AppConfigDict]
|
||||||
|
|
||||||
|
@field_validator("from_version", "to_version", mode="before")
|
||||||
|
@classmethod
|
||||||
|
def _string_to_version(cls, v: str | AppVersion) -> AppVersion: # noqa D102
|
||||||
|
if isinstance(v, str):
|
||||||
|
return AppVersion(v)
|
||||||
|
else:
|
||||||
|
return v
|
||||||
|
|
||||||
|
|
||||||
|
class ConfigMigrator(ConfigMigratorBase):
|
||||||
|
"""This class allows migrators to register their input and output versions."""
|
||||||
|
|
||||||
|
_migrations: List[MigrationEntry] = []
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def register(
|
||||||
|
cls,
|
||||||
|
from_version: AppVersion | str,
|
||||||
|
to_version: AppVersion | str,
|
||||||
|
) -> Callable[[Callable[[AppConfigDict], AppConfigDict]], Callable[[AppConfigDict], AppConfigDict]]:
|
||||||
|
"""Define a decorator which registers the migration between two versions."""
|
||||||
|
|
||||||
|
def decorator(function: Callable[[AppConfigDict], AppConfigDict]) -> Callable[[AppConfigDict], AppConfigDict]:
|
||||||
|
if from_version in cls._migrations:
|
||||||
|
raise ValueError(
|
||||||
|
f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered."
|
||||||
|
)
|
||||||
|
cls._migrations.append(MigrationEntry(from_version=from_version, to_version=to_version, function=function))
|
||||||
|
return function
|
||||||
|
|
||||||
|
return decorator
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def _check_for_overlaps(migrations: List[MigrationEntry]) -> None:
|
||||||
|
current_version = AppVersion("0.0.0")
|
||||||
|
for m in migrations:
|
||||||
|
if current_version > m.from_version:
|
||||||
|
raise ValueError(f"Version range overlap detected while processing function {m.function.__name__}")
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict:
|
||||||
|
"""
|
||||||
|
Use the registered migration steps to bring config up to latest version.
|
||||||
|
|
||||||
|
:param config: The original configuration.
|
||||||
|
:return: The new configuration, lifted up to the latest version.
|
||||||
|
|
||||||
|
As a side effect, the new configuration will be written to disk.
|
||||||
|
If an inconsistency in the registered migration steps' `from_version`
|
||||||
|
and `to_version` parameters are identified, this will raise a
|
||||||
|
ValueError exception.
|
||||||
|
"""
|
||||||
|
# Sort migrations by version number and raise a ValueError if
|
||||||
|
# any version range overlaps are detected. Discontinuities are ok
|
||||||
|
sorted_migrations = sorted(cls._migrations, key=lambda x: x.from_version)
|
||||||
|
cls._check_for_overlaps(sorted_migrations)
|
||||||
|
|
||||||
|
if "InvokeAI" in config_dict:
|
||||||
|
version = AppVersion("3.0.0")
|
||||||
|
else:
|
||||||
|
version = AppVersion(config_dict["schema_version"])
|
||||||
|
|
||||||
|
for migration in sorted_migrations:
|
||||||
|
if version >= migration.from_version and version < migration.to_version:
|
||||||
|
config_dict = migration.function(config_dict)
|
||||||
|
version = migration.to_version
|
||||||
|
|
||||||
|
config_dict["schema_version"] = str(version)
|
||||||
|
return config_dict
|
@ -90,6 +90,7 @@ dependencies = [
|
|||||||
"semver~=3.0.1",
|
"semver~=3.0.1",
|
||||||
"send2trash",
|
"send2trash",
|
||||||
"test-tube~=0.7.5",
|
"test-tube~=0.7.5",
|
||||||
|
"version-parser",
|
||||||
"windows-curses; sys_platform=='win32'",
|
"windows-curses; sys_platform=='win32'",
|
||||||
]
|
]
|
||||||
|
|
||||||
|
Loading…
Reference in New Issue
Block a user