From a48abfacf4f3d8cf3c67c35be26a91b9990e78b7 Mon Sep 17 00:00:00 2001 From: Lincoln Stein Date: Thu, 2 May 2024 23:45:34 -0400 Subject: [PATCH] make config migrator into an instance; refactor location of get_config() --- invokeai/app/api_app.py | 2 +- invokeai/app/invocations/__init__.py | 2 +- invokeai/app/invocations/baseinvocation.py | 2 +- invokeai/app/services/config/__init__.py | 3 +- .../app/services/config/config_default.py | 166 ------------------ .../app/services/config/config_migrate.py | 121 +++++++++++-- invokeai/app/services/config/migrations.py | 108 ++++++++++++ .../app/services/shared/invocation_context.py | 2 +- .../app/services/shared/sqlite/sqlite_util.py | 2 +- .../sqlite_migrator/migrations/migration_8.py | 2 +- .../image_util/depth_anything/__init__.py | 2 +- .../image_util/dw_openpose/wholebody.py | 2 +- .../backend/image_util/infill_methods/lama.py | 2 +- .../image_util/infill_methods/patchmatch.py | 2 +- .../backend/image_util/invisible_watermark.py | 2 +- invokeai/backend/image_util/safety_checker.py | 2 +- .../stable_diffusion/diffusers_pipeline.py | 2 +- .../diffusion/shared_invokeai_diffusion.py | 2 +- invokeai/backend/util/devices.py | 2 +- invokeai/backend/util/logging.py | 3 +- tests/test_config.py | 102 ++++++++--- 21 files changed, 314 insertions(+), 219 deletions(-) create mode 100644 invokeai/app/services/config/migrations.py diff --git a/invokeai/app/api_app.py b/invokeai/app/api_app.py index ceaeb95147..2efd338c1e 100644 --- a/invokeai/app/api_app.py +++ b/invokeai/app/api_app.py @@ -26,7 +26,7 @@ import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import) import invokeai.frontend.web as web_dir from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles from invokeai.app.invocations.model import ModelIdentifierField -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.backend.util.devices import TorchDevice diff --git a/invokeai/app/invocations/__init__.py b/invokeai/app/invocations/__init__.py index cb1caa167e..f9b0932b04 100644 --- a/invokeai/app/invocations/__init__.py +++ b/invokeai/app/invocations/__init__.py @@ -3,7 +3,7 @@ import sys from importlib.util import module_from_spec, spec_from_file_location from pathlib import Path -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config custom_nodes_path = Path(get_config().custom_nodes_path) custom_nodes_path.mkdir(parents=True, exist_ok=True) diff --git a/invokeai/app/invocations/baseinvocation.py b/invokeai/app/invocations/baseinvocation.py index 40c7b41cae..ee4b88fa2e 100644 --- a/invokeai/app/invocations/baseinvocation.py +++ b/invokeai/app/invocations/baseinvocation.py @@ -33,7 +33,7 @@ from invokeai.app.invocations.fields import ( FieldKind, Input, ) -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.misc import uuid_string diff --git a/invokeai/app/services/config/__init__.py b/invokeai/app/services/config/__init__.py index 126692f08a..ac154386da 100644 --- a/invokeai/app/services/config/__init__.py +++ b/invokeai/app/services/config/__init__.py @@ -2,6 +2,7 @@ from invokeai.app.services.config.config_common import PagingArgumentParser -from .config_default import InvokeAIAppConfig, get_config +from .config_default import InvokeAIAppConfig +from .config_migrate import get_config __all__ = ["InvokeAIAppConfig", "get_config", "PagingArgumentParser"] diff --git a/invokeai/app/services/config/config_default.py b/invokeai/app/services/config/config_default.py index 0aedb54a37..06d0ea2783 100644 --- a/invokeai/app/services/config/config_default.py +++ b/invokeai/app/services/config/config_default.py @@ -3,11 +3,8 @@ from __future__ import annotations -import locale import os import re -import shutil -from functools import lru_cache from pathlib import Path from typing import Any, Literal, Optional @@ -16,11 +13,7 @@ import yaml from pydantic import BaseModel, Field, PrivateAttr, field_validator from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict -import invokeai.configs as model_configs from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS -from invokeai.frontend.cli.arg_parser import InvokeAIArgs - -from .config_migrate import ConfigMigrator INIT_FILE = Path("invokeai.yaml") DB_FILE = Path("invokeai.db") @@ -348,162 +341,3 @@ class DefaultInvokeAIAppConfig(InvokeAIAppConfig): file_secret_settings: PydanticBaseSettingsSource, ) -> tuple[PydanticBaseSettingsSource, ...]: return (init_settings,) - - -def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: - """Load and migrate a config file to the latest version. - - Args: - config_path: Path to the config file. - - Returns: - An instance of `InvokeAIAppConfig` with the loaded and migrated settings. - """ - assert config_path.suffix == ".yaml" - with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file: - loaded_config_dict = yaml.safe_load(file) - - assert isinstance(loaded_config_dict, dict) - - shutil.copy(config_path, config_path.with_suffix(".yaml.bak")) - try: - # loaded_config_dict could be the wrong shape, but we will catch all exceptions below - migrated_config_dict = ConfigMigrator.migrate(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType] - except Exception as e: - shutil.copy(config_path.with_suffix(".yaml.bak"), config_path) - raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e - - # Attempt to load as a v4 config file - try: - config = InvokeAIAppConfig.model_validate(migrated_config_dict) - assert ( - config.schema_version == CONFIG_SCHEMA_VERSION - ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}" - return config - except Exception as e: - raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e - - -@lru_cache(maxsize=1) -def get_config() -> InvokeAIAppConfig: - """Get the global singleton app config. - - When first called, this function: - - Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file. - - Retrieves any provided CLI args from the InvokeAIArgs class. It does not _parse_ the CLI args; that is done in the main entrypoint. - - Sets the root dir, if provided via CLI args. - - Logs in to HF if there is no valid token already. - - Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers). - - Reads and merges in settings from the config file if it exists, else writes out a default config file. - - On subsequent calls, the object is returned from the cache. - """ - # This object includes environment variables, as parsed by pydantic-settings - config = InvokeAIAppConfig() - - args = InvokeAIArgs.args - - # This flag serves as a proxy for whether the config was retrieved in the context of the full application or not. - # If it is False, we should just return a default config and not set the root, log in to HF, etc. - if not InvokeAIArgs.did_parse: - return config - - # Set CLI args - if root := getattr(args, "root", None): - config._root = Path(root) - if config_file := getattr(args, "config_file", None): - config._config_file = Path(config_file) - - # Create the example config file, with some extra example values provided - example_config = DefaultInvokeAIAppConfig() - example_config.remote_api_tokens = [ - URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"), - URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"), - ] - example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True) - - # Copy all legacy configs - We know `__path__[0]` is correct here - configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue] - shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True) - - if config.config_file_path.exists(): - config_from_file = load_and_migrate_config(config.config_file_path) - config_from_file.write_file(config.config_file_path) - # Clobbering here will overwrite any settings that were set via environment variables - config.update_config(config_from_file, clobber=False) - else: - # We should never write env vars to the config file - default_config = DefaultInvokeAIAppConfig() - default_config.write_file(config.config_file_path, as_example=False) - - return config - - -#################################################### -# VERSION MIGRATIONS -#################################################### - - -@ConfigMigrator.register(from_version="3.0.0", to_version="4.0.0") -def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]: - """Migrate a v3 config dictionary to a current config object. - - Args: - config_dict: A dictionary of settings from a v3 config file. - - Returns: - A dictionary of settings from a 4.0.0 config file. - - """ - parsed_config_dict: dict[str, Any] = {} - for _category_name, category_dict in config_dict["InvokeAI"].items(): - for k, v in category_dict.items(): - # `outdir` was renamed to `outputs_dir` in v4 - if k == "outdir": - parsed_config_dict["outputs_dir"] = v - # `max_cache_size` was renamed to `ram` some time in v3, but both names were used - if k == "max_cache_size" and "ram" not in category_dict: - parsed_config_dict["ram"] = v - # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used - if k == "max_vram_cache_size" and "vram" not in category_dict: - parsed_config_dict["vram"] = v - # autocast was removed in v4.0.1 - if k == "precision" and v == "autocast": - parsed_config_dict["precision"] = "auto" - if k == "conf_path": - parsed_config_dict["legacy_models_yaml_path"] = v - if k == "legacy_conf_dir": - # The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows). - if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion": - # If if the incoming config has the default value, skip - continue - elif Path(v).name == "stable-diffusion": - # Else if the path ends in "stable-diffusion", we assume the parent is the new correct path. - parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent) - else: - # Else we do not attempt to migrate this setting - parsed_config_dict["legacy_conf_dir"] = v - elif k in InvokeAIAppConfig.model_fields: - # skip unknown fields - parsed_config_dict[k] = v - return parsed_config_dict - - -@ConfigMigrator.register(from_version="4.0.0", to_version="4.0.1") -def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]: - """Migrate v4.0.0 config dictionary to v4.0.1. - - Args: - config_dict: A dictionary of settings from a v4.0.0 config file. - - Returns: - A dictionary of settings from a v4.0.1 config file - """ - parsed_config_dict: dict[str, Any] = {} - for k, v in config_dict.items(): - # autocast was removed from precision in v4.0.1 - if k == "precision" and v == "autocast": - parsed_config_dict["precision"] = "auto" - else: - parsed_config_dict[k] = v - return parsed_config_dict diff --git a/invokeai/app/services/config/config_migrate.py b/invokeai/app/services/config/config_migrate.py index b3fe979d37..899fc43f39 100644 --- a/invokeai/app/services/config/config_migrate.py +++ b/invokeai/app/services/config/config_migrate.py @@ -4,12 +4,22 @@ Utility class for migrating among versions of the InvokeAI app config schema. """ +import locale +import shutil from dataclasses import dataclass -from typing import Any, Callable, List, TypeAlias +from functools import lru_cache +from pathlib import Path +from typing import Callable, List, TypeAlias +import yaml from packaging.version import Version -AppConfigDict: TypeAlias = dict[str, Any] +import invokeai.configs as model_configs +from invokeai.frontend.cli.arg_parser import InvokeAIArgs + +from .config_default import CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, URLRegexTokenPair +from .migrations import AppConfigDict, Migrations, MigrationsBase + MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict] @@ -25,22 +35,23 @@ class MigrationEntry: class ConfigMigrator: """This class allows migrators to register their input and output versions.""" - _migrations: List[MigrationEntry] = [] + def __init__(self, migrations: type[MigrationsBase]) -> None: + self._migrations: List[MigrationEntry] = [] + migrations.load(self) - @classmethod def register( - cls, + self, from_version: str, to_version: str, ) -> Callable[[MigrationFunction], MigrationFunction]: """Define a decorator which registers the migration between two versions.""" def decorator(function: MigrationFunction) -> MigrationFunction: - if any(from_version == m.from_version for m in cls._migrations): + if any((from_version == m.from_version) or (to_version == m.to_version) for m in self._migrations): raise ValueError( f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered." ) - cls._migrations.append( + self._migrations.append( MigrationEntry(from_version=Version(from_version), to_version=Version(to_version), function=function) ) return function @@ -57,8 +68,7 @@ class ConfigMigrator: ) current_version = m.to_version - @classmethod - def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict: + def run_migrations(self, config_dict: AppConfigDict) -> AppConfigDict: """ Use the registered migration steps to bring config up to latest version. @@ -72,8 +82,8 @@ class ConfigMigrator: """ # Sort migrations by version number and raise a ValueError if # any version range overlaps are detected. - sorted_migrations = sorted(cls._migrations, key=lambda x: x.from_version) - cls._check_for_discontinuities(sorted_migrations) + sorted_migrations = sorted(self._migrations, key=lambda x: x.from_version) + self._check_for_discontinuities(sorted_migrations) if "InvokeAI" in config_dict: version = Version("3.0.0") @@ -87,3 +97,92 @@ class ConfigMigrator: config_dict["schema_version"] = str(version) return config_dict + + +@lru_cache(maxsize=1) +def get_config() -> InvokeAIAppConfig: + """Get the global singleton app config. + + When first called, this function: + - Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file. + - Retrieves any provided CLI args from the InvokeAIArgs class. It does not _parse_ the CLI args; that is done in the main entrypoint. + - Sets the root dir, if provided via CLI args. + - Logs in to HF if there is no valid token already. + - Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers). + - Reads and merges in settings from the config file if it exists, else writes out a default config file. + + On subsequent calls, the object is returned from the cache. + """ + # This object includes environment variables, as parsed by pydantic-settings + config = InvokeAIAppConfig() + + args = InvokeAIArgs.args + + # This flag serves as a proxy for whether the config was retrieved in the context of the full application or not. + # If it is False, we should just return a default config and not set the root, log in to HF, etc. + if not InvokeAIArgs.did_parse: + return config + + # Set CLI args + if root := getattr(args, "root", None): + config._root = Path(root) + if config_file := getattr(args, "config_file", None): + config._config_file = Path(config_file) + + # Create the example config file, with some extra example values provided + example_config = DefaultInvokeAIAppConfig() + example_config.remote_api_tokens = [ + URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"), + URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"), + ] + example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True) + + # Copy all legacy configs - We know `__path__[0]` is correct here + configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue] + shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True) + + if config.config_file_path.exists(): + config_from_file = load_and_migrate_config(config.config_file_path) + config_from_file.write_file(config.config_file_path) + # Clobbering here will overwrite any settings that were set via environment variables + config.update_config(config_from_file, clobber=False) + else: + # We should never write env vars to the config file + default_config = DefaultInvokeAIAppConfig() + default_config.write_file(config.config_file_path, as_example=False) + + return config + + +def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig: + """Load and migrate a config file to the latest version. + + Args: + config_path: Path to the config file. + + Returns: + An instance of `InvokeAIAppConfig` with the loaded and migrated settings. + """ + assert config_path.suffix == ".yaml" + with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file: + loaded_config_dict = yaml.safe_load(file) + + assert isinstance(loaded_config_dict, dict) + + shutil.copy(config_path, config_path.with_suffix(".yaml.bak")) + try: + migrator = ConfigMigrator(Migrations) + migrated_config_dict = migrator.run_migrations(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType] + except Exception as e: + shutil.copy(config_path.with_suffix(".yaml.bak"), config_path) + raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e + + # Attempt to load as a v4 config file + try: + config = InvokeAIAppConfig.model_validate(migrated_config_dict) + assert ( + config.schema_version == CONFIG_SCHEMA_VERSION + ), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}" + return config + except Exception as e: + raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e diff --git a/invokeai/app/services/config/migrations.py b/invokeai/app/services/config/migrations.py new file mode 100644 index 0000000000..4c6996d7db --- /dev/null +++ b/invokeai/app/services/config/migrations.py @@ -0,0 +1,108 @@ +# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team + +""" +Schema migrations to perform on an InvokeAIAppConfig object. + +The Migrations class defined in this module defines a series of +schema version migration steps for the InvokeAIConfig object. + +To define a new migration, add a migration function to +Migrations.load_migrations() following the existing examples. +""" + +from abc import ABC, abstractmethod +from pathlib import Path +from typing import TYPE_CHECKING, Any, TypeAlias + +from .config_default import InvokeAIAppConfig + +if TYPE_CHECKING: + from .config_migrate import ConfigMigrator + +AppConfigDict: TypeAlias = dict[str, Any] + + +class MigrationsBase(ABC): + """Define the config file migration steps to apply, abstract base class.""" + + @classmethod + @abstractmethod + def load(cls, migrator: "ConfigMigrator") -> None: + """Use the provided migrator to register the configuration migrations to be run.""" + + +class Migrations(MigrationsBase): + """Configuration migration steps to apply.""" + + @classmethod + def load(cls, migrator: "ConfigMigrator") -> None: + """Define migrations to perform.""" + + ################## + # 3.0.0 -> 4.0.0 # + ################## + @migrator.register(from_version="3.0.0", to_version="4.0.0") + def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]: + """Migrate a v3 config dictionary to a current config object. + + Args: + config_dict: A dictionary of settings from a v3 config file. + + Returns: + A dictionary of settings from a 4.0.0 config file. + + """ + parsed_config_dict: dict[str, Any] = {} + for _category_name, category_dict in config_dict["InvokeAI"].items(): + for k, v in category_dict.items(): + # `outdir` was renamed to `outputs_dir` in v4 + if k == "outdir": + parsed_config_dict["outputs_dir"] = v + # `max_cache_size` was renamed to `ram` some time in v3, but both names were used + if k == "max_cache_size" and "ram" not in category_dict: + parsed_config_dict["ram"] = v + # `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used + if k == "max_vram_cache_size" and "vram" not in category_dict: + parsed_config_dict["vram"] = v + # autocast was removed in v4.0.1 + if k == "precision" and v == "autocast": + parsed_config_dict["precision"] = "auto" + if k == "conf_path": + parsed_config_dict["legacy_models_yaml_path"] = v + if k == "legacy_conf_dir": + # The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows). + if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion": + # If if the incoming config has the default value, skip + continue + elif Path(v).name == "stable-diffusion": + # Else if the path ends in "stable-diffusion", we assume the parent is the new correct path. + parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent) + else: + # Else we do not attempt to migrate this setting + parsed_config_dict["legacy_conf_dir"] = v + elif k in InvokeAIAppConfig.model_fields: + # skip unknown fields + parsed_config_dict[k] = v + return parsed_config_dict + + ################## + # 4.0.0 -> 4.0.1 # + ################## + @migrator.register(from_version="4.0.0", to_version="4.0.1") + def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]: + """Migrate v4.0.0 config dictionary to v4.0.1. + + Args: + config_dict: A dictionary of settings from a v4.0.0 config file. + + Returns: + A dictionary of settings from a v4.0.1 config file + """ + parsed_config_dict: dict[str, Any] = {} + for k, v in config_dict.items(): + # autocast was removed from precision in v4.0.1 + if k == "precision" and v == "autocast": + parsed_config_dict["precision"] = "auto" + else: + parsed_config_dict[k] = v + return parsed_config_dict diff --git a/invokeai/app/services/shared/invocation_context.py b/invokeai/app/services/shared/invocation_context.py index 9994d663e5..68f7b11bcd 100644 --- a/invokeai/app/services/shared/invocation_context.py +++ b/invokeai/app/services/shared/invocation_context.py @@ -9,7 +9,7 @@ from torch import Tensor from invokeai.app.invocations.constants import IMAGE_MODES from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata from invokeai.app.services.boards.boards_common import BoardDTO -from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.invocation_services import InvocationServices diff --git a/invokeai/app/services/shared/sqlite/sqlite_util.py b/invokeai/app/services/shared/sqlite/sqlite_util.py index 1eed0b4409..2dcaaa8aed 100644 --- a/invokeai/app/services/shared/sqlite/sqlite_util.py +++ b/invokeai/app/services/shared/sqlite/sqlite_util.py @@ -1,6 +1,6 @@ from logging import Logger -from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1 diff --git a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_8.py b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_8.py index 154a5236ca..4fb8cf46ef 100644 --- a/invokeai/app/services/shared/sqlite_migrator/migrations/migration_8.py +++ b/invokeai/app/services/shared/sqlite_migrator/migrations/migration_8.py @@ -1,7 +1,7 @@ import sqlite3 from pathlib import Path -from invokeai.app.services.config.config_default import InvokeAIAppConfig +from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration diff --git a/invokeai/backend/image_util/depth_anything/__init__.py b/invokeai/backend/image_util/depth_anything/__init__.py index c854fba3f2..ac3692cec8 100644 --- a/invokeai/backend/image_util/depth_anything/__init__.py +++ b/invokeai/backend/image_util/depth_anything/__init__.py @@ -9,7 +9,7 @@ from einops import repeat from PIL import Image from torchvision.transforms import Compose -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2 from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize diff --git a/invokeai/backend/image_util/dw_openpose/wholebody.py b/invokeai/backend/image_util/dw_openpose/wholebody.py index 84f5afa989..750f480fb2 100644 --- a/invokeai/backend/image_util/dw_openpose/wholebody.py +++ b/invokeai/backend/image_util/dw_openpose/wholebody.py @@ -5,7 +5,7 @@ import numpy as np import onnxruntime as ort -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.backend.util.devices import TorchDevice diff --git a/invokeai/backend/image_util/infill_methods/lama.py b/invokeai/backend/image_util/infill_methods/lama.py index 4268ec773d..247dcb83c4 100644 --- a/invokeai/backend/image_util/infill_methods/lama.py +++ b/invokeai/backend/image_util/infill_methods/lama.py @@ -6,7 +6,7 @@ import torch from PIL import Image import invokeai.backend.util.logging as logger -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.backend.util.devices import TorchDevice diff --git a/invokeai/backend/image_util/infill_methods/patchmatch.py b/invokeai/backend/image_util/infill_methods/patchmatch.py index 7e9cdf8fa4..0ed05a1ab6 100644 --- a/invokeai/backend/image_util/infill_methods/patchmatch.py +++ b/invokeai/backend/image_util/infill_methods/patchmatch.py @@ -9,7 +9,7 @@ import numpy as np from PIL import Image import invokeai.backend.util.logging as logger -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config class PatchMatch: diff --git a/invokeai/backend/image_util/invisible_watermark.py b/invokeai/backend/image_util/invisible_watermark.py index 84342e442f..18c87dbffd 100644 --- a/invokeai/backend/image_util/invisible_watermark.py +++ b/invokeai/backend/image_util/invisible_watermark.py @@ -10,7 +10,7 @@ from imwatermark import WatermarkEncoder from PIL import Image import invokeai.backend.util.logging as logger -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config config = get_config() diff --git a/invokeai/backend/image_util/safety_checker.py b/invokeai/backend/image_util/safety_checker.py index 60dcd93fcc..afc51c241a 100644 --- a/invokeai/backend/image_util/safety_checker.py +++ b/invokeai/backend/image_util/safety_checker.py @@ -12,7 +12,7 @@ from PIL import Image from transformers import AutoFeatureExtractor import invokeai.backend.util.logging as logger -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.silence_warnings import SilenceWarnings diff --git a/invokeai/backend/stable_diffusion/diffusers_pipeline.py b/invokeai/backend/stable_diffusion/diffusers_pipeline.py index 8b90c815ae..cac3f58d08 100644 --- a/invokeai/backend/stable_diffusion/diffusers_pipeline.py +++ b/invokeai/backend/stable_diffusion/diffusers_pipeline.py @@ -20,7 +20,7 @@ from diffusers.utils.import_utils import is_xformers_available from pydantic import Field from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData diff --git a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py index f418133e49..c355208831 100644 --- a/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py +++ b/invokeai/backend/stable_diffusion/diffusion/shared_invokeai_diffusion.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Optional, Union import torch from typing_extensions import TypeAlias -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( IPAdapterData, Range, diff --git a/invokeai/backend/util/devices.py b/invokeai/backend/util/devices.py index e8380dc8bc..b7320bc9f0 100644 --- a/invokeai/backend/util/devices.py +++ b/invokeai/backend/util/devices.py @@ -3,7 +3,7 @@ from typing import Dict, Literal, Optional, Union import torch from deprecated import deprecated -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import get_config # legacy APIs TorchPrecisionNames = Literal["float32", "float16", "bfloat16"] diff --git a/invokeai/backend/util/logging.py b/invokeai/backend/util/logging.py index 968604eb3d..217a1a6cb8 100644 --- a/invokeai/backend/util/logging.py +++ b/invokeai/backend/util/logging.py @@ -180,8 +180,7 @@ import urllib.parse from pathlib import Path from typing import Any, Dict, Optional -from invokeai.app.services.config import InvokeAIAppConfig -from invokeai.app.services.config.config_default import get_config +from invokeai.app.services.config import InvokeAIAppConfig, get_config try: import syslog diff --git a/tests/test_config.py b/tests/test_config.py index 80c7ccc950..9b3c140e5e 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -12,10 +12,9 @@ from invokeai.app.services.config.config_default import ( CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, - get_config, - load_and_migrate_config, ) -from invokeai.app.services.config.config_migrate import ConfigMigrator +from invokeai.app.services.config.config_migrate import ConfigMigrator, get_config, load_and_migrate_config +from invokeai.app.services.config.migrations import Migrations, MigrationsBase from invokeai.app.services.shared.graph import Graph from invokeai.frontend.cli.arg_parser import InvokeAIArgs @@ -72,6 +71,68 @@ i like turtles """ +class GoodMigrations(MigrationsBase): + @classmethod + def load(cls, migrator: ConfigMigrator) -> None: + """Define migrations to perform.""" + + @migrator.register(from_version="3.0.0", to_version="10.0.0") + def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + @migrator.register(from_version="10.0.0", to_version="10.0.1") + def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + @migrator.register(from_version="10.0.1", to_version="10.0.2") + def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + +class BadMigrations1(MigrationsBase): + """This one fails because there is no path from 10.0.1 to 10.0.2""" + + @classmethod + def load(cls, migrator: ConfigMigrator) -> None: + """Define migrations to perform.""" + + @migrator.register(from_version="3.0.0", to_version="10.0.0") + def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + @migrator.register(from_version="10.0.0", to_version="10.0.1") + def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + @migrator.register(from_version="10.0.2", to_version="10.0.3") + def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + +class BadMigrations2(MigrationsBase): + """This one fails because the path to 10.0.2 is registered twice""" + + @classmethod + def load(cls, migrator: ConfigMigrator) -> None: + """Define migrations to perform.""" + + @migrator.register(from_version="3.0.0", to_version="10.0.0") + def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + @migrator.register(from_version="10.0.0", to_version="10.0.1") + def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + @migrator.register(from_version="10.0.1", to_version="10.0.2") + def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + @migrator.register(from_version="10.0.0", to_version="10.0.2") + def migration_4(config_dict: dict[str, Any]) -> dict[str, Any]: + return config_dict + + @pytest.fixture def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None: """This may be overkill since the current tests don't need the root dir to exist""" @@ -291,33 +352,26 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch def test_migration_check() -> None: - new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"}) + # Test the default set of migrations + migrator = ConfigMigrator(Migrations) + new_config = migrator.run_migrations({"schema_version": "4.0.0"}) assert new_config is not None assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION - # Does this execute at compile time or run time? - @ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION, to_version=CONFIG_SCHEMA_VERSION + ".1") - def ok_migration(config_dict: dict[str, Any]) -> dict[str, Any]: - return config_dict + # Test a custom set of migrations + migrator = ConfigMigrator(GoodMigrations) + new_config = migrator.run_migrations({"schema_version": "10.0.0"}) + assert new_config["schema_version"] == "10.0.2" - new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"}) - assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".1" - - @ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".2", to_version=CONFIG_SCHEMA_VERSION + ".3") - def bad_migration(config_dict: dict[str, Any]) -> dict[str, Any]: - return config_dict - - # Because there is no version for "*.1" => "*.2", this should fail. + # Test a migration that should fail validation + migrator = ConfigMigrator(BadMigrations1) with pytest.raises(ValueError): - ConfigMigrator.migrate({"schema_version": "4.0.0"}) + new_config = migrator.run_migrations({"schema_version": "10.0.0"}) - @ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".1", to_version=CONFIG_SCHEMA_VERSION + ".2") - def good_migration(config_dict: dict[str, Any]) -> dict[str, Any]: - return config_dict - - # should work now, because there is a continuous path to *.3 - new_config = ConfigMigrator.migrate(new_config) - assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".3" + # Test another bad migration + migrator = ConfigMigrator(BadMigrations2) + with pytest.raises(ValueError): + new_config = migrator.run_migrations({"schema_version": "10.0.0"}) @contextmanager