mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make config migrator into an instance; refactor location of get_config()
This commit is contained in:
parent
d5aee87684
commit
a48abfacf4
@ -26,7 +26,7 @@ import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
|||||||
import invokeai.frontend.web as web_dir
|
import invokeai.frontend.web as web_dir
|
||||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||||
from invokeai.app.invocations.model import ModelIdentifierField
|
from invokeai.app.invocations.model import ModelIdentifierField
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config import get_config
|
||||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
|
@ -3,7 +3,7 @@ import sys
|
|||||||
from importlib.util import module_from_spec, spec_from_file_location
|
from importlib.util import module_from_spec, spec_from_file_location
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config import get_config
|
||||||
|
|
||||||
custom_nodes_path = Path(get_config().custom_nodes_path)
|
custom_nodes_path = Path(get_config().custom_nodes_path)
|
||||||
custom_nodes_path.mkdir(parents=True, exist_ok=True)
|
custom_nodes_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
@ -33,7 +33,7 @@ from invokeai.app.invocations.fields import (
|
|||||||
FieldKind,
|
FieldKind,
|
||||||
Input,
|
Input,
|
||||||
)
|
)
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config import get_config
|
||||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||||
from invokeai.app.util.metaenum import MetaEnum
|
from invokeai.app.util.metaenum import MetaEnum
|
||||||
from invokeai.app.util.misc import uuid_string
|
from invokeai.app.util.misc import uuid_string
|
||||||
|
@ -2,6 +2,7 @@
|
|||||||
|
|
||||||
from invokeai.app.services.config.config_common import PagingArgumentParser
|
from invokeai.app.services.config.config_common import PagingArgumentParser
|
||||||
|
|
||||||
from .config_default import InvokeAIAppConfig, get_config
|
from .config_default import InvokeAIAppConfig
|
||||||
|
from .config_migrate import get_config
|
||||||
|
|
||||||
__all__ = ["InvokeAIAppConfig", "get_config", "PagingArgumentParser"]
|
__all__ = ["InvokeAIAppConfig", "get_config", "PagingArgumentParser"]
|
||||||
|
@ -3,11 +3,8 @@
|
|||||||
|
|
||||||
from __future__ import annotations
|
from __future__ import annotations
|
||||||
|
|
||||||
import locale
|
|
||||||
import os
|
import os
|
||||||
import re
|
import re
|
||||||
import shutil
|
|
||||||
from functools import lru_cache
|
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Literal, Optional
|
from typing import Any, Literal, Optional
|
||||||
|
|
||||||
@ -16,11 +13,7 @@ import yaml
|
|||||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||||
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
|
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
|
||||||
|
|
||||||
import invokeai.configs as model_configs
|
|
||||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
|
||||||
|
|
||||||
from .config_migrate import ConfigMigrator
|
|
||||||
|
|
||||||
INIT_FILE = Path("invokeai.yaml")
|
INIT_FILE = Path("invokeai.yaml")
|
||||||
DB_FILE = Path("invokeai.db")
|
DB_FILE = Path("invokeai.db")
|
||||||
@ -348,162 +341,3 @@ class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
|
|||||||
file_secret_settings: PydanticBaseSettingsSource,
|
file_secret_settings: PydanticBaseSettingsSource,
|
||||||
) -> tuple[PydanticBaseSettingsSource, ...]:
|
) -> tuple[PydanticBaseSettingsSource, ...]:
|
||||||
return (init_settings,)
|
return (init_settings,)
|
||||||
|
|
||||||
|
|
||||||
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
|
||||||
"""Load and migrate a config file to the latest version.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_path: Path to the config file.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
|
|
||||||
"""
|
|
||||||
assert config_path.suffix == ".yaml"
|
|
||||||
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
|
|
||||||
loaded_config_dict = yaml.safe_load(file)
|
|
||||||
|
|
||||||
assert isinstance(loaded_config_dict, dict)
|
|
||||||
|
|
||||||
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
|
|
||||||
try:
|
|
||||||
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
|
|
||||||
migrated_config_dict = ConfigMigrator.migrate(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
|
|
||||||
except Exception as e:
|
|
||||||
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
|
|
||||||
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e
|
|
||||||
|
|
||||||
# Attempt to load as a v4 config file
|
|
||||||
try:
|
|
||||||
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
|
|
||||||
assert (
|
|
||||||
config.schema_version == CONFIG_SCHEMA_VERSION
|
|
||||||
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}"
|
|
||||||
return config
|
|
||||||
except Exception as e:
|
|
||||||
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
|
||||||
|
|
||||||
|
|
||||||
@lru_cache(maxsize=1)
|
|
||||||
def get_config() -> InvokeAIAppConfig:
|
|
||||||
"""Get the global singleton app config.
|
|
||||||
|
|
||||||
When first called, this function:
|
|
||||||
- Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file.
|
|
||||||
- Retrieves any provided CLI args from the InvokeAIArgs class. It does not _parse_ the CLI args; that is done in the main entrypoint.
|
|
||||||
- Sets the root dir, if provided via CLI args.
|
|
||||||
- Logs in to HF if there is no valid token already.
|
|
||||||
- Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers).
|
|
||||||
- Reads and merges in settings from the config file if it exists, else writes out a default config file.
|
|
||||||
|
|
||||||
On subsequent calls, the object is returned from the cache.
|
|
||||||
"""
|
|
||||||
# This object includes environment variables, as parsed by pydantic-settings
|
|
||||||
config = InvokeAIAppConfig()
|
|
||||||
|
|
||||||
args = InvokeAIArgs.args
|
|
||||||
|
|
||||||
# This flag serves as a proxy for whether the config was retrieved in the context of the full application or not.
|
|
||||||
# If it is False, we should just return a default config and not set the root, log in to HF, etc.
|
|
||||||
if not InvokeAIArgs.did_parse:
|
|
||||||
return config
|
|
||||||
|
|
||||||
# Set CLI args
|
|
||||||
if root := getattr(args, "root", None):
|
|
||||||
config._root = Path(root)
|
|
||||||
if config_file := getattr(args, "config_file", None):
|
|
||||||
config._config_file = Path(config_file)
|
|
||||||
|
|
||||||
# Create the example config file, with some extra example values provided
|
|
||||||
example_config = DefaultInvokeAIAppConfig()
|
|
||||||
example_config.remote_api_tokens = [
|
|
||||||
URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"),
|
|
||||||
URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"),
|
|
||||||
]
|
|
||||||
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
|
|
||||||
|
|
||||||
# Copy all legacy configs - We know `__path__[0]` is correct here
|
|
||||||
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
|
|
||||||
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
|
|
||||||
|
|
||||||
if config.config_file_path.exists():
|
|
||||||
config_from_file = load_and_migrate_config(config.config_file_path)
|
|
||||||
config_from_file.write_file(config.config_file_path)
|
|
||||||
# Clobbering here will overwrite any settings that were set via environment variables
|
|
||||||
config.update_config(config_from_file, clobber=False)
|
|
||||||
else:
|
|
||||||
# We should never write env vars to the config file
|
|
||||||
default_config = DefaultInvokeAIAppConfig()
|
|
||||||
default_config.write_file(config.config_file_path, as_example=False)
|
|
||||||
|
|
||||||
return config
|
|
||||||
|
|
||||||
|
|
||||||
####################################################
|
|
||||||
# VERSION MIGRATIONS
|
|
||||||
####################################################
|
|
||||||
|
|
||||||
|
|
||||||
@ConfigMigrator.register(from_version="3.0.0", to_version="4.0.0")
|
|
||||||
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Migrate a v3 config dictionary to a current config object.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dict: A dictionary of settings from a v3 config file.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary of settings from a 4.0.0 config file.
|
|
||||||
|
|
||||||
"""
|
|
||||||
parsed_config_dict: dict[str, Any] = {}
|
|
||||||
for _category_name, category_dict in config_dict["InvokeAI"].items():
|
|
||||||
for k, v in category_dict.items():
|
|
||||||
# `outdir` was renamed to `outputs_dir` in v4
|
|
||||||
if k == "outdir":
|
|
||||||
parsed_config_dict["outputs_dir"] = v
|
|
||||||
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
|
|
||||||
if k == "max_cache_size" and "ram" not in category_dict:
|
|
||||||
parsed_config_dict["ram"] = v
|
|
||||||
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
|
||||||
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
|
||||||
parsed_config_dict["vram"] = v
|
|
||||||
# autocast was removed in v4.0.1
|
|
||||||
if k == "precision" and v == "autocast":
|
|
||||||
parsed_config_dict["precision"] = "auto"
|
|
||||||
if k == "conf_path":
|
|
||||||
parsed_config_dict["legacy_models_yaml_path"] = v
|
|
||||||
if k == "legacy_conf_dir":
|
|
||||||
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
|
|
||||||
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
|
|
||||||
# If if the incoming config has the default value, skip
|
|
||||||
continue
|
|
||||||
elif Path(v).name == "stable-diffusion":
|
|
||||||
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
|
||||||
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
|
|
||||||
else:
|
|
||||||
# Else we do not attempt to migrate this setting
|
|
||||||
parsed_config_dict["legacy_conf_dir"] = v
|
|
||||||
elif k in InvokeAIAppConfig.model_fields:
|
|
||||||
# skip unknown fields
|
|
||||||
parsed_config_dict[k] = v
|
|
||||||
return parsed_config_dict
|
|
||||||
|
|
||||||
|
|
||||||
@ConfigMigrator.register(from_version="4.0.0", to_version="4.0.1")
|
|
||||||
def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
"""Migrate v4.0.0 config dictionary to v4.0.1.
|
|
||||||
|
|
||||||
Args:
|
|
||||||
config_dict: A dictionary of settings from a v4.0.0 config file.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
A dictionary of settings from a v4.0.1 config file
|
|
||||||
"""
|
|
||||||
parsed_config_dict: dict[str, Any] = {}
|
|
||||||
for k, v in config_dict.items():
|
|
||||||
# autocast was removed from precision in v4.0.1
|
|
||||||
if k == "precision" and v == "autocast":
|
|
||||||
parsed_config_dict["precision"] = "auto"
|
|
||||||
else:
|
|
||||||
parsed_config_dict[k] = v
|
|
||||||
return parsed_config_dict
|
|
||||||
|
@ -4,12 +4,22 @@
|
|||||||
Utility class for migrating among versions of the InvokeAI app config schema.
|
Utility class for migrating among versions of the InvokeAI app config schema.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
import locale
|
||||||
|
import shutil
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Callable, List, TypeAlias
|
from functools import lru_cache
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Callable, List, TypeAlias
|
||||||
|
|
||||||
|
import yaml
|
||||||
from packaging.version import Version
|
from packaging.version import Version
|
||||||
|
|
||||||
AppConfigDict: TypeAlias = dict[str, Any]
|
import invokeai.configs as model_configs
|
||||||
|
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||||
|
|
||||||
|
from .config_default import CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, URLRegexTokenPair
|
||||||
|
from .migrations import AppConfigDict, Migrations, MigrationsBase
|
||||||
|
|
||||||
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
|
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
|
||||||
|
|
||||||
|
|
||||||
@ -25,22 +35,23 @@ class MigrationEntry:
|
|||||||
class ConfigMigrator:
|
class ConfigMigrator:
|
||||||
"""This class allows migrators to register their input and output versions."""
|
"""This class allows migrators to register their input and output versions."""
|
||||||
|
|
||||||
_migrations: List[MigrationEntry] = []
|
def __init__(self, migrations: type[MigrationsBase]) -> None:
|
||||||
|
self._migrations: List[MigrationEntry] = []
|
||||||
|
migrations.load(self)
|
||||||
|
|
||||||
@classmethod
|
|
||||||
def register(
|
def register(
|
||||||
cls,
|
self,
|
||||||
from_version: str,
|
from_version: str,
|
||||||
to_version: str,
|
to_version: str,
|
||||||
) -> Callable[[MigrationFunction], MigrationFunction]:
|
) -> Callable[[MigrationFunction], MigrationFunction]:
|
||||||
"""Define a decorator which registers the migration between two versions."""
|
"""Define a decorator which registers the migration between two versions."""
|
||||||
|
|
||||||
def decorator(function: MigrationFunction) -> MigrationFunction:
|
def decorator(function: MigrationFunction) -> MigrationFunction:
|
||||||
if any(from_version == m.from_version for m in cls._migrations):
|
if any((from_version == m.from_version) or (to_version == m.to_version) for m in self._migrations):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered."
|
f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered."
|
||||||
)
|
)
|
||||||
cls._migrations.append(
|
self._migrations.append(
|
||||||
MigrationEntry(from_version=Version(from_version), to_version=Version(to_version), function=function)
|
MigrationEntry(from_version=Version(from_version), to_version=Version(to_version), function=function)
|
||||||
)
|
)
|
||||||
return function
|
return function
|
||||||
@ -57,8 +68,7 @@ class ConfigMigrator:
|
|||||||
)
|
)
|
||||||
current_version = m.to_version
|
current_version = m.to_version
|
||||||
|
|
||||||
@classmethod
|
def run_migrations(self, config_dict: AppConfigDict) -> AppConfigDict:
|
||||||
def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict:
|
|
||||||
"""
|
"""
|
||||||
Use the registered migration steps to bring config up to latest version.
|
Use the registered migration steps to bring config up to latest version.
|
||||||
|
|
||||||
@ -72,8 +82,8 @@ class ConfigMigrator:
|
|||||||
"""
|
"""
|
||||||
# Sort migrations by version number and raise a ValueError if
|
# Sort migrations by version number and raise a ValueError if
|
||||||
# any version range overlaps are detected.
|
# any version range overlaps are detected.
|
||||||
sorted_migrations = sorted(cls._migrations, key=lambda x: x.from_version)
|
sorted_migrations = sorted(self._migrations, key=lambda x: x.from_version)
|
||||||
cls._check_for_discontinuities(sorted_migrations)
|
self._check_for_discontinuities(sorted_migrations)
|
||||||
|
|
||||||
if "InvokeAI" in config_dict:
|
if "InvokeAI" in config_dict:
|
||||||
version = Version("3.0.0")
|
version = Version("3.0.0")
|
||||||
@ -87,3 +97,92 @@ class ConfigMigrator:
|
|||||||
|
|
||||||
config_dict["schema_version"] = str(version)
|
config_dict["schema_version"] = str(version)
|
||||||
return config_dict
|
return config_dict
|
||||||
|
|
||||||
|
|
||||||
|
@lru_cache(maxsize=1)
|
||||||
|
def get_config() -> InvokeAIAppConfig:
|
||||||
|
"""Get the global singleton app config.
|
||||||
|
|
||||||
|
When first called, this function:
|
||||||
|
- Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file.
|
||||||
|
- Retrieves any provided CLI args from the InvokeAIArgs class. It does not _parse_ the CLI args; that is done in the main entrypoint.
|
||||||
|
- Sets the root dir, if provided via CLI args.
|
||||||
|
- Logs in to HF if there is no valid token already.
|
||||||
|
- Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers).
|
||||||
|
- Reads and merges in settings from the config file if it exists, else writes out a default config file.
|
||||||
|
|
||||||
|
On subsequent calls, the object is returned from the cache.
|
||||||
|
"""
|
||||||
|
# This object includes environment variables, as parsed by pydantic-settings
|
||||||
|
config = InvokeAIAppConfig()
|
||||||
|
|
||||||
|
args = InvokeAIArgs.args
|
||||||
|
|
||||||
|
# This flag serves as a proxy for whether the config was retrieved in the context of the full application or not.
|
||||||
|
# If it is False, we should just return a default config and not set the root, log in to HF, etc.
|
||||||
|
if not InvokeAIArgs.did_parse:
|
||||||
|
return config
|
||||||
|
|
||||||
|
# Set CLI args
|
||||||
|
if root := getattr(args, "root", None):
|
||||||
|
config._root = Path(root)
|
||||||
|
if config_file := getattr(args, "config_file", None):
|
||||||
|
config._config_file = Path(config_file)
|
||||||
|
|
||||||
|
# Create the example config file, with some extra example values provided
|
||||||
|
example_config = DefaultInvokeAIAppConfig()
|
||||||
|
example_config.remote_api_tokens = [
|
||||||
|
URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"),
|
||||||
|
URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"),
|
||||||
|
]
|
||||||
|
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
|
||||||
|
|
||||||
|
# Copy all legacy configs - We know `__path__[0]` is correct here
|
||||||
|
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
|
||||||
|
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
|
||||||
|
|
||||||
|
if config.config_file_path.exists():
|
||||||
|
config_from_file = load_and_migrate_config(config.config_file_path)
|
||||||
|
config_from_file.write_file(config.config_file_path)
|
||||||
|
# Clobbering here will overwrite any settings that were set via environment variables
|
||||||
|
config.update_config(config_from_file, clobber=False)
|
||||||
|
else:
|
||||||
|
# We should never write env vars to the config file
|
||||||
|
default_config = DefaultInvokeAIAppConfig()
|
||||||
|
default_config.write_file(config.config_file_path, as_example=False)
|
||||||
|
|
||||||
|
return config
|
||||||
|
|
||||||
|
|
||||||
|
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||||
|
"""Load and migrate a config file to the latest version.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_path: Path to the config file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
|
||||||
|
"""
|
||||||
|
assert config_path.suffix == ".yaml"
|
||||||
|
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
|
||||||
|
loaded_config_dict = yaml.safe_load(file)
|
||||||
|
|
||||||
|
assert isinstance(loaded_config_dict, dict)
|
||||||
|
|
||||||
|
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
|
||||||
|
try:
|
||||||
|
migrator = ConfigMigrator(Migrations)
|
||||||
|
migrated_config_dict = migrator.run_migrations(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
|
||||||
|
except Exception as e:
|
||||||
|
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
|
||||||
|
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e
|
||||||
|
|
||||||
|
# Attempt to load as a v4 config file
|
||||||
|
try:
|
||||||
|
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
|
||||||
|
assert (
|
||||||
|
config.schema_version == CONFIG_SCHEMA_VERSION
|
||||||
|
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}"
|
||||||
|
return config
|
||||||
|
except Exception as e:
|
||||||
|
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
||||||
|
108
invokeai/app/services/config/migrations.py
Normal file
108
invokeai/app/services/config/migrations.py
Normal file
@ -0,0 +1,108 @@
|
|||||||
|
# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team
|
||||||
|
|
||||||
|
"""
|
||||||
|
Schema migrations to perform on an InvokeAIAppConfig object.
|
||||||
|
|
||||||
|
The Migrations class defined in this module defines a series of
|
||||||
|
schema version migration steps for the InvokeAIConfig object.
|
||||||
|
|
||||||
|
To define a new migration, add a migration function to
|
||||||
|
Migrations.load_migrations() following the existing examples.
|
||||||
|
"""
|
||||||
|
|
||||||
|
from abc import ABC, abstractmethod
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import TYPE_CHECKING, Any, TypeAlias
|
||||||
|
|
||||||
|
from .config_default import InvokeAIAppConfig
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from .config_migrate import ConfigMigrator
|
||||||
|
|
||||||
|
AppConfigDict: TypeAlias = dict[str, Any]
|
||||||
|
|
||||||
|
|
||||||
|
class MigrationsBase(ABC):
|
||||||
|
"""Define the config file migration steps to apply, abstract base class."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
@abstractmethod
|
||||||
|
def load(cls, migrator: "ConfigMigrator") -> None:
|
||||||
|
"""Use the provided migrator to register the configuration migrations to be run."""
|
||||||
|
|
||||||
|
|
||||||
|
class Migrations(MigrationsBase):
|
||||||
|
"""Configuration migration steps to apply."""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, migrator: "ConfigMigrator") -> None:
|
||||||
|
"""Define migrations to perform."""
|
||||||
|
|
||||||
|
##################
|
||||||
|
# 3.0.0 -> 4.0.0 #
|
||||||
|
##################
|
||||||
|
@migrator.register(from_version="3.0.0", to_version="4.0.0")
|
||||||
|
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Migrate a v3 config dictionary to a current config object.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_dict: A dictionary of settings from a v3 config file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of settings from a 4.0.0 config file.
|
||||||
|
|
||||||
|
"""
|
||||||
|
parsed_config_dict: dict[str, Any] = {}
|
||||||
|
for _category_name, category_dict in config_dict["InvokeAI"].items():
|
||||||
|
for k, v in category_dict.items():
|
||||||
|
# `outdir` was renamed to `outputs_dir` in v4
|
||||||
|
if k == "outdir":
|
||||||
|
parsed_config_dict["outputs_dir"] = v
|
||||||
|
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
|
||||||
|
if k == "max_cache_size" and "ram" not in category_dict:
|
||||||
|
parsed_config_dict["ram"] = v
|
||||||
|
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
||||||
|
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
||||||
|
parsed_config_dict["vram"] = v
|
||||||
|
# autocast was removed in v4.0.1
|
||||||
|
if k == "precision" and v == "autocast":
|
||||||
|
parsed_config_dict["precision"] = "auto"
|
||||||
|
if k == "conf_path":
|
||||||
|
parsed_config_dict["legacy_models_yaml_path"] = v
|
||||||
|
if k == "legacy_conf_dir":
|
||||||
|
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
|
||||||
|
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
|
||||||
|
# If if the incoming config has the default value, skip
|
||||||
|
continue
|
||||||
|
elif Path(v).name == "stable-diffusion":
|
||||||
|
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
||||||
|
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
|
||||||
|
else:
|
||||||
|
# Else we do not attempt to migrate this setting
|
||||||
|
parsed_config_dict["legacy_conf_dir"] = v
|
||||||
|
elif k in InvokeAIAppConfig.model_fields:
|
||||||
|
# skip unknown fields
|
||||||
|
parsed_config_dict[k] = v
|
||||||
|
return parsed_config_dict
|
||||||
|
|
||||||
|
##################
|
||||||
|
# 4.0.0 -> 4.0.1 #
|
||||||
|
##################
|
||||||
|
@migrator.register(from_version="4.0.0", to_version="4.0.1")
|
||||||
|
def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
"""Migrate v4.0.0 config dictionary to v4.0.1.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
config_dict: A dictionary of settings from a v4.0.0 config file.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
A dictionary of settings from a v4.0.1 config file
|
||||||
|
"""
|
||||||
|
parsed_config_dict: dict[str, Any] = {}
|
||||||
|
for k, v in config_dict.items():
|
||||||
|
# autocast was removed from precision in v4.0.1
|
||||||
|
if k == "precision" and v == "autocast":
|
||||||
|
parsed_config_dict["precision"] = "auto"
|
||||||
|
else:
|
||||||
|
parsed_config_dict[k] = v
|
||||||
|
return parsed_config_dict
|
@ -9,7 +9,7 @@ from torch import Tensor
|
|||||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||||
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
|
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
|
||||||
from invokeai.app.services.boards.boards_common import BoardDTO
|
from invokeai.app.services.boards.boards_common import BoardDTO
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||||
from invokeai.app.services.images.images_common import ImageDTO
|
from invokeai.app.services.images.images_common import ImageDTO
|
||||||
from invokeai.app.services.invocation_services import InvocationServices
|
from invokeai.app.services.invocation_services import InvocationServices
|
||||||
|
@ -1,6 +1,6 @@
|
|||||||
from logging import Logger
|
from logging import Logger
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1
|
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1
|
||||||
|
@ -1,7 +1,7 @@
|
|||||||
import sqlite3
|
import sqlite3
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig
|
||||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||||
|
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ from einops import repeat
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
from torchvision.transforms import Compose
|
from torchvision.transforms import Compose
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config import get_config
|
||||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||||
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
|
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
|
||||||
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
|
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
|
||||||
|
@ -5,7 +5,7 @@
|
|||||||
import numpy as np
|
import numpy as np
|
||||||
import onnxruntime as ort
|
import onnxruntime as ort
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config import get_config
|
||||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
|
@ -6,7 +6,7 @@ import torch
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config import get_config
|
||||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
|
|
||||||
|
@ -9,7 +9,7 @@ import numpy as np
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config import get_config
|
||||||
|
|
||||||
|
|
||||||
class PatchMatch:
|
class PatchMatch:
|
||||||
|
@ -10,7 +10,7 @@ from imwatermark import WatermarkEncoder
|
|||||||
from PIL import Image
|
from PIL import Image
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config import get_config
|
||||||
|
|
||||||
config = get_config()
|
config = get_config()
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ from PIL import Image
|
|||||||
from transformers import AutoFeatureExtractor
|
from transformers import AutoFeatureExtractor
|
||||||
|
|
||||||
import invokeai.backend.util.logging as logger
|
import invokeai.backend.util.logging as logger
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config import get_config
|
||||||
from invokeai.backend.util.devices import TorchDevice
|
from invokeai.backend.util.devices import TorchDevice
|
||||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||||
|
|
||||||
|
@ -20,7 +20,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
|||||||
from pydantic import Field
|
from pydantic import Field
|
||||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config import get_config
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
|
||||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
|
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
|
||||||
|
@ -6,7 +6,7 @@ from typing import Any, Callable, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
from typing_extensions import TypeAlias
|
from typing_extensions import TypeAlias
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config import get_config
|
||||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||||
IPAdapterData,
|
IPAdapterData,
|
||||||
Range,
|
Range,
|
||||||
|
@ -3,7 +3,7 @@ from typing import Dict, Literal, Optional, Union
|
|||||||
import torch
|
import torch
|
||||||
from deprecated import deprecated
|
from deprecated import deprecated
|
||||||
|
|
||||||
from invokeai.app.services.config.config_default import get_config
|
from invokeai.app.services.config import get_config
|
||||||
|
|
||||||
# legacy APIs
|
# legacy APIs
|
||||||
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
|
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
|
||||||
|
@ -180,8 +180,7 @@ import urllib.parse
|
|||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
from invokeai.app.services.config import InvokeAIAppConfig
|
from invokeai.app.services.config import InvokeAIAppConfig, get_config
|
||||||
from invokeai.app.services.config.config_default import get_config
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
import syslog
|
import syslog
|
||||||
|
@ -12,10 +12,9 @@ from invokeai.app.services.config.config_default import (
|
|||||||
CONFIG_SCHEMA_VERSION,
|
CONFIG_SCHEMA_VERSION,
|
||||||
DefaultInvokeAIAppConfig,
|
DefaultInvokeAIAppConfig,
|
||||||
InvokeAIAppConfig,
|
InvokeAIAppConfig,
|
||||||
get_config,
|
|
||||||
load_and_migrate_config,
|
|
||||||
)
|
)
|
||||||
from invokeai.app.services.config.config_migrate import ConfigMigrator
|
from invokeai.app.services.config.config_migrate import ConfigMigrator, get_config, load_and_migrate_config
|
||||||
|
from invokeai.app.services.config.migrations import Migrations, MigrationsBase
|
||||||
from invokeai.app.services.shared.graph import Graph
|
from invokeai.app.services.shared.graph import Graph
|
||||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||||
|
|
||||||
@ -72,6 +71,68 @@ i like turtles
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class GoodMigrations(MigrationsBase):
|
||||||
|
@classmethod
|
||||||
|
def load(cls, migrator: ConfigMigrator) -> None:
|
||||||
|
"""Define migrations to perform."""
|
||||||
|
|
||||||
|
@migrator.register(from_version="3.0.0", to_version="10.0.0")
|
||||||
|
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
return config_dict
|
||||||
|
|
||||||
|
@migrator.register(from_version="10.0.0", to_version="10.0.1")
|
||||||
|
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
return config_dict
|
||||||
|
|
||||||
|
@migrator.register(from_version="10.0.1", to_version="10.0.2")
|
||||||
|
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
return config_dict
|
||||||
|
|
||||||
|
|
||||||
|
class BadMigrations1(MigrationsBase):
|
||||||
|
"""This one fails because there is no path from 10.0.1 to 10.0.2"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, migrator: ConfigMigrator) -> None:
|
||||||
|
"""Define migrations to perform."""
|
||||||
|
|
||||||
|
@migrator.register(from_version="3.0.0", to_version="10.0.0")
|
||||||
|
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
return config_dict
|
||||||
|
|
||||||
|
@migrator.register(from_version="10.0.0", to_version="10.0.1")
|
||||||
|
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
return config_dict
|
||||||
|
|
||||||
|
@migrator.register(from_version="10.0.2", to_version="10.0.3")
|
||||||
|
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
return config_dict
|
||||||
|
|
||||||
|
|
||||||
|
class BadMigrations2(MigrationsBase):
|
||||||
|
"""This one fails because the path to 10.0.2 is registered twice"""
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def load(cls, migrator: ConfigMigrator) -> None:
|
||||||
|
"""Define migrations to perform."""
|
||||||
|
|
||||||
|
@migrator.register(from_version="3.0.0", to_version="10.0.0")
|
||||||
|
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
return config_dict
|
||||||
|
|
||||||
|
@migrator.register(from_version="10.0.0", to_version="10.0.1")
|
||||||
|
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
return config_dict
|
||||||
|
|
||||||
|
@migrator.register(from_version="10.0.1", to_version="10.0.2")
|
||||||
|
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
return config_dict
|
||||||
|
|
||||||
|
@migrator.register(from_version="10.0.0", to_version="10.0.2")
|
||||||
|
def migration_4(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||||
|
return config_dict
|
||||||
|
|
||||||
|
|
||||||
@pytest.fixture
|
@pytest.fixture
|
||||||
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
|
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
|
||||||
"""This may be overkill since the current tests don't need the root dir to exist"""
|
"""This may be overkill since the current tests don't need the root dir to exist"""
|
||||||
@ -291,33 +352,26 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch
|
|||||||
|
|
||||||
|
|
||||||
def test_migration_check() -> None:
|
def test_migration_check() -> None:
|
||||||
new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"})
|
# Test the default set of migrations
|
||||||
|
migrator = ConfigMigrator(Migrations)
|
||||||
|
new_config = migrator.run_migrations({"schema_version": "4.0.0"})
|
||||||
assert new_config is not None
|
assert new_config is not None
|
||||||
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION
|
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION
|
||||||
|
|
||||||
# Does this execute at compile time or run time?
|
# Test a custom set of migrations
|
||||||
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION, to_version=CONFIG_SCHEMA_VERSION + ".1")
|
migrator = ConfigMigrator(GoodMigrations)
|
||||||
def ok_migration(config_dict: dict[str, Any]) -> dict[str, Any]:
|
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
|
||||||
return config_dict
|
assert new_config["schema_version"] == "10.0.2"
|
||||||
|
|
||||||
new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"})
|
# Test a migration that should fail validation
|
||||||
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".1"
|
migrator = ConfigMigrator(BadMigrations1)
|
||||||
|
|
||||||
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".2", to_version=CONFIG_SCHEMA_VERSION + ".3")
|
|
||||||
def bad_migration(config_dict: dict[str, Any]) -> dict[str, Any]:
|
|
||||||
return config_dict
|
|
||||||
|
|
||||||
# Because there is no version for "*.1" => "*.2", this should fail.
|
|
||||||
with pytest.raises(ValueError):
|
with pytest.raises(ValueError):
|
||||||
ConfigMigrator.migrate({"schema_version": "4.0.0"})
|
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
|
||||||
|
|
||||||
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".1", to_version=CONFIG_SCHEMA_VERSION + ".2")
|
# Test another bad migration
|
||||||
def good_migration(config_dict: dict[str, Any]) -> dict[str, Any]:
|
migrator = ConfigMigrator(BadMigrations2)
|
||||||
return config_dict
|
with pytest.raises(ValueError):
|
||||||
|
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
|
||||||
# should work now, because there is a continuous path to *.3
|
|
||||||
new_config = ConfigMigrator.migrate(new_config)
|
|
||||||
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".3"
|
|
||||||
|
|
||||||
|
|
||||||
@contextmanager
|
@contextmanager
|
||||||
|
Loading…
x
Reference in New Issue
Block a user