mirror of
https://github.com/invoke-ai/InvokeAI
synced 2024-08-30 20:32:17 +00:00
make config migrator into an instance; refactor location of get_config()
This commit is contained in:
parent
d5aee87684
commit
a48abfacf4
@ -26,7 +26,7 @@ import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
|
||||
import invokeai.frontend.web as web_dir
|
||||
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
|
||||
from invokeai.app.invocations.model import ModelIdentifierField
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.session_processor.session_processor_common import ProgressImage
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
@ -3,7 +3,7 @@ import sys
|
||||
from importlib.util import module_from_spec, spec_from_file_location
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
|
||||
custom_nodes_path = Path(get_config().custom_nodes_path)
|
||||
custom_nodes_path.mkdir(parents=True, exist_ok=True)
|
||||
|
@ -33,7 +33,7 @@ from invokeai.app.invocations.fields import (
|
||||
FieldKind,
|
||||
Input,
|
||||
)
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.services.shared.invocation_context import InvocationContext
|
||||
from invokeai.app.util.metaenum import MetaEnum
|
||||
from invokeai.app.util.misc import uuid_string
|
||||
|
@ -2,6 +2,7 @@
|
||||
|
||||
from invokeai.app.services.config.config_common import PagingArgumentParser
|
||||
|
||||
from .config_default import InvokeAIAppConfig, get_config
|
||||
from .config_default import InvokeAIAppConfig
|
||||
from .config_migrate import get_config
|
||||
|
||||
__all__ = ["InvokeAIAppConfig", "get_config", "PagingArgumentParser"]
|
||||
|
@ -3,11 +3,8 @@
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import locale
|
||||
import os
|
||||
import re
|
||||
import shutil
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Any, Literal, Optional
|
||||
|
||||
@ -16,11 +13,7 @@ import yaml
|
||||
from pydantic import BaseModel, Field, PrivateAttr, field_validator
|
||||
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
|
||||
|
||||
import invokeai.configs as model_configs
|
||||
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
from .config_migrate import ConfigMigrator
|
||||
|
||||
INIT_FILE = Path("invokeai.yaml")
|
||||
DB_FILE = Path("invokeai.db")
|
||||
@ -348,162 +341,3 @@ class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
|
||||
file_secret_settings: PydanticBaseSettingsSource,
|
||||
) -> tuple[PydanticBaseSettingsSource, ...]:
|
||||
return (init_settings,)
|
||||
|
||||
|
||||
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
"""Load and migrate a config file to the latest version.
|
||||
|
||||
Args:
|
||||
config_path: Path to the config file.
|
||||
|
||||
Returns:
|
||||
An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
|
||||
"""
|
||||
assert config_path.suffix == ".yaml"
|
||||
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
|
||||
loaded_config_dict = yaml.safe_load(file)
|
||||
|
||||
assert isinstance(loaded_config_dict, dict)
|
||||
|
||||
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
|
||||
try:
|
||||
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
|
||||
migrated_config_dict = ConfigMigrator.migrate(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
|
||||
except Exception as e:
|
||||
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
|
||||
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e
|
||||
|
||||
# Attempt to load as a v4 config file
|
||||
try:
|
||||
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
|
||||
assert (
|
||||
config.schema_version == CONFIG_SCHEMA_VERSION
|
||||
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}"
|
||||
return config
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_config() -> InvokeAIAppConfig:
|
||||
"""Get the global singleton app config.
|
||||
|
||||
When first called, this function:
|
||||
- Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file.
|
||||
- Retrieves any provided CLI args from the InvokeAIArgs class. It does not _parse_ the CLI args; that is done in the main entrypoint.
|
||||
- Sets the root dir, if provided via CLI args.
|
||||
- Logs in to HF if there is no valid token already.
|
||||
- Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers).
|
||||
- Reads and merges in settings from the config file if it exists, else writes out a default config file.
|
||||
|
||||
On subsequent calls, the object is returned from the cache.
|
||||
"""
|
||||
# This object includes environment variables, as parsed by pydantic-settings
|
||||
config = InvokeAIAppConfig()
|
||||
|
||||
args = InvokeAIArgs.args
|
||||
|
||||
# This flag serves as a proxy for whether the config was retrieved in the context of the full application or not.
|
||||
# If it is False, we should just return a default config and not set the root, log in to HF, etc.
|
||||
if not InvokeAIArgs.did_parse:
|
||||
return config
|
||||
|
||||
# Set CLI args
|
||||
if root := getattr(args, "root", None):
|
||||
config._root = Path(root)
|
||||
if config_file := getattr(args, "config_file", None):
|
||||
config._config_file = Path(config_file)
|
||||
|
||||
# Create the example config file, with some extra example values provided
|
||||
example_config = DefaultInvokeAIAppConfig()
|
||||
example_config.remote_api_tokens = [
|
||||
URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"),
|
||||
URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"),
|
||||
]
|
||||
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
|
||||
|
||||
# Copy all legacy configs - We know `__path__[0]` is correct here
|
||||
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
|
||||
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
|
||||
|
||||
if config.config_file_path.exists():
|
||||
config_from_file = load_and_migrate_config(config.config_file_path)
|
||||
config_from_file.write_file(config.config_file_path)
|
||||
# Clobbering here will overwrite any settings that were set via environment variables
|
||||
config.update_config(config_from_file, clobber=False)
|
||||
else:
|
||||
# We should never write env vars to the config file
|
||||
default_config = DefaultInvokeAIAppConfig()
|
||||
default_config.write_file(config.config_file_path, as_example=False)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
####################################################
|
||||
# VERSION MIGRATIONS
|
||||
####################################################
|
||||
|
||||
|
||||
@ConfigMigrator.register(from_version="3.0.0", to_version="4.0.0")
|
||||
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Migrate a v3 config dictionary to a current config object.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v3 config file.
|
||||
|
||||
Returns:
|
||||
A dictionary of settings from a 4.0.0 config file.
|
||||
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for _category_name, category_dict in config_dict["InvokeAI"].items():
|
||||
for k, v in category_dict.items():
|
||||
# `outdir` was renamed to `outputs_dir` in v4
|
||||
if k == "outdir":
|
||||
parsed_config_dict["outputs_dir"] = v
|
||||
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
|
||||
if k == "max_cache_size" and "ram" not in category_dict:
|
||||
parsed_config_dict["ram"] = v
|
||||
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
||||
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
||||
parsed_config_dict["vram"] = v
|
||||
# autocast was removed in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
if k == "conf_path":
|
||||
parsed_config_dict["legacy_models_yaml_path"] = v
|
||||
if k == "legacy_conf_dir":
|
||||
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
|
||||
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
|
||||
# If if the incoming config has the default value, skip
|
||||
continue
|
||||
elif Path(v).name == "stable-diffusion":
|
||||
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
||||
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
|
||||
else:
|
||||
# Else we do not attempt to migrate this setting
|
||||
parsed_config_dict["legacy_conf_dir"] = v
|
||||
elif k in InvokeAIAppConfig.model_fields:
|
||||
# skip unknown fields
|
||||
parsed_config_dict[k] = v
|
||||
return parsed_config_dict
|
||||
|
||||
|
||||
@ConfigMigrator.register(from_version="4.0.0", to_version="4.0.1")
|
||||
def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Migrate v4.0.0 config dictionary to v4.0.1.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v4.0.0 config file.
|
||||
|
||||
Returns:
|
||||
A dictionary of settings from a v4.0.1 config file
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for k, v in config_dict.items():
|
||||
# autocast was removed from precision in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
else:
|
||||
parsed_config_dict[k] = v
|
||||
return parsed_config_dict
|
||||
|
@ -4,12 +4,22 @@
|
||||
Utility class for migrating among versions of the InvokeAI app config schema.
|
||||
"""
|
||||
|
||||
import locale
|
||||
import shutil
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, Callable, List, TypeAlias
|
||||
from functools import lru_cache
|
||||
from pathlib import Path
|
||||
from typing import Callable, List, TypeAlias
|
||||
|
||||
import yaml
|
||||
from packaging.version import Version
|
||||
|
||||
AppConfigDict: TypeAlias = dict[str, Any]
|
||||
import invokeai.configs as model_configs
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
from .config_default import CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, URLRegexTokenPair
|
||||
from .migrations import AppConfigDict, Migrations, MigrationsBase
|
||||
|
||||
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
|
||||
|
||||
|
||||
@ -25,22 +35,23 @@ class MigrationEntry:
|
||||
class ConfigMigrator:
|
||||
"""This class allows migrators to register their input and output versions."""
|
||||
|
||||
_migrations: List[MigrationEntry] = []
|
||||
def __init__(self, migrations: type[MigrationsBase]) -> None:
|
||||
self._migrations: List[MigrationEntry] = []
|
||||
migrations.load(self)
|
||||
|
||||
@classmethod
|
||||
def register(
|
||||
cls,
|
||||
self,
|
||||
from_version: str,
|
||||
to_version: str,
|
||||
) -> Callable[[MigrationFunction], MigrationFunction]:
|
||||
"""Define a decorator which registers the migration between two versions."""
|
||||
|
||||
def decorator(function: MigrationFunction) -> MigrationFunction:
|
||||
if any(from_version == m.from_version for m in cls._migrations):
|
||||
if any((from_version == m.from_version) or (to_version == m.to_version) for m in self._migrations):
|
||||
raise ValueError(
|
||||
f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered."
|
||||
)
|
||||
cls._migrations.append(
|
||||
self._migrations.append(
|
||||
MigrationEntry(from_version=Version(from_version), to_version=Version(to_version), function=function)
|
||||
)
|
||||
return function
|
||||
@ -57,8 +68,7 @@ class ConfigMigrator:
|
||||
)
|
||||
current_version = m.to_version
|
||||
|
||||
@classmethod
|
||||
def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict:
|
||||
def run_migrations(self, config_dict: AppConfigDict) -> AppConfigDict:
|
||||
"""
|
||||
Use the registered migration steps to bring config up to latest version.
|
||||
|
||||
@ -72,8 +82,8 @@ class ConfigMigrator:
|
||||
"""
|
||||
# Sort migrations by version number and raise a ValueError if
|
||||
# any version range overlaps are detected.
|
||||
sorted_migrations = sorted(cls._migrations, key=lambda x: x.from_version)
|
||||
cls._check_for_discontinuities(sorted_migrations)
|
||||
sorted_migrations = sorted(self._migrations, key=lambda x: x.from_version)
|
||||
self._check_for_discontinuities(sorted_migrations)
|
||||
|
||||
if "InvokeAI" in config_dict:
|
||||
version = Version("3.0.0")
|
||||
@ -87,3 +97,92 @@ class ConfigMigrator:
|
||||
|
||||
config_dict["schema_version"] = str(version)
|
||||
return config_dict
|
||||
|
||||
|
||||
@lru_cache(maxsize=1)
|
||||
def get_config() -> InvokeAIAppConfig:
|
||||
"""Get the global singleton app config.
|
||||
|
||||
When first called, this function:
|
||||
- Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file.
|
||||
- Retrieves any provided CLI args from the InvokeAIArgs class. It does not _parse_ the CLI args; that is done in the main entrypoint.
|
||||
- Sets the root dir, if provided via CLI args.
|
||||
- Logs in to HF if there is no valid token already.
|
||||
- Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers).
|
||||
- Reads and merges in settings from the config file if it exists, else writes out a default config file.
|
||||
|
||||
On subsequent calls, the object is returned from the cache.
|
||||
"""
|
||||
# This object includes environment variables, as parsed by pydantic-settings
|
||||
config = InvokeAIAppConfig()
|
||||
|
||||
args = InvokeAIArgs.args
|
||||
|
||||
# This flag serves as a proxy for whether the config was retrieved in the context of the full application or not.
|
||||
# If it is False, we should just return a default config and not set the root, log in to HF, etc.
|
||||
if not InvokeAIArgs.did_parse:
|
||||
return config
|
||||
|
||||
# Set CLI args
|
||||
if root := getattr(args, "root", None):
|
||||
config._root = Path(root)
|
||||
if config_file := getattr(args, "config_file", None):
|
||||
config._config_file = Path(config_file)
|
||||
|
||||
# Create the example config file, with some extra example values provided
|
||||
example_config = DefaultInvokeAIAppConfig()
|
||||
example_config.remote_api_tokens = [
|
||||
URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"),
|
||||
URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"),
|
||||
]
|
||||
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
|
||||
|
||||
# Copy all legacy configs - We know `__path__[0]` is correct here
|
||||
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
|
||||
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
|
||||
|
||||
if config.config_file_path.exists():
|
||||
config_from_file = load_and_migrate_config(config.config_file_path)
|
||||
config_from_file.write_file(config.config_file_path)
|
||||
# Clobbering here will overwrite any settings that were set via environment variables
|
||||
config.update_config(config_from_file, clobber=False)
|
||||
else:
|
||||
# We should never write env vars to the config file
|
||||
default_config = DefaultInvokeAIAppConfig()
|
||||
default_config.write_file(config.config_file_path, as_example=False)
|
||||
|
||||
return config
|
||||
|
||||
|
||||
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
|
||||
"""Load and migrate a config file to the latest version.
|
||||
|
||||
Args:
|
||||
config_path: Path to the config file.
|
||||
|
||||
Returns:
|
||||
An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
|
||||
"""
|
||||
assert config_path.suffix == ".yaml"
|
||||
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
|
||||
loaded_config_dict = yaml.safe_load(file)
|
||||
|
||||
assert isinstance(loaded_config_dict, dict)
|
||||
|
||||
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
|
||||
try:
|
||||
migrator = ConfigMigrator(Migrations)
|
||||
migrated_config_dict = migrator.run_migrations(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
|
||||
except Exception as e:
|
||||
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
|
||||
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e
|
||||
|
||||
# Attempt to load as a v4 config file
|
||||
try:
|
||||
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
|
||||
assert (
|
||||
config.schema_version == CONFIG_SCHEMA_VERSION
|
||||
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}"
|
||||
return config
|
||||
except Exception as e:
|
||||
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
|
||||
|
108
invokeai/app/services/config/migrations.py
Normal file
108
invokeai/app/services/config/migrations.py
Normal file
@ -0,0 +1,108 @@
|
||||
# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team
|
||||
|
||||
"""
|
||||
Schema migrations to perform on an InvokeAIAppConfig object.
|
||||
|
||||
The Migrations class defined in this module defines a series of
|
||||
schema version migration steps for the InvokeAIConfig object.
|
||||
|
||||
To define a new migration, add a migration function to
|
||||
Migrations.load_migrations() following the existing examples.
|
||||
"""
|
||||
|
||||
from abc import ABC, abstractmethod
|
||||
from pathlib import Path
|
||||
from typing import TYPE_CHECKING, Any, TypeAlias
|
||||
|
||||
from .config_default import InvokeAIAppConfig
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from .config_migrate import ConfigMigrator
|
||||
|
||||
AppConfigDict: TypeAlias = dict[str, Any]
|
||||
|
||||
|
||||
class MigrationsBase(ABC):
|
||||
"""Define the config file migration steps to apply, abstract base class."""
|
||||
|
||||
@classmethod
|
||||
@abstractmethod
|
||||
def load(cls, migrator: "ConfigMigrator") -> None:
|
||||
"""Use the provided migrator to register the configuration migrations to be run."""
|
||||
|
||||
|
||||
class Migrations(MigrationsBase):
|
||||
"""Configuration migration steps to apply."""
|
||||
|
||||
@classmethod
|
||||
def load(cls, migrator: "ConfigMigrator") -> None:
|
||||
"""Define migrations to perform."""
|
||||
|
||||
##################
|
||||
# 3.0.0 -> 4.0.0 #
|
||||
##################
|
||||
@migrator.register(from_version="3.0.0", to_version="4.0.0")
|
||||
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Migrate a v3 config dictionary to a current config object.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v3 config file.
|
||||
|
||||
Returns:
|
||||
A dictionary of settings from a 4.0.0 config file.
|
||||
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for _category_name, category_dict in config_dict["InvokeAI"].items():
|
||||
for k, v in category_dict.items():
|
||||
# `outdir` was renamed to `outputs_dir` in v4
|
||||
if k == "outdir":
|
||||
parsed_config_dict["outputs_dir"] = v
|
||||
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
|
||||
if k == "max_cache_size" and "ram" not in category_dict:
|
||||
parsed_config_dict["ram"] = v
|
||||
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
|
||||
if k == "max_vram_cache_size" and "vram" not in category_dict:
|
||||
parsed_config_dict["vram"] = v
|
||||
# autocast was removed in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
if k == "conf_path":
|
||||
parsed_config_dict["legacy_models_yaml_path"] = v
|
||||
if k == "legacy_conf_dir":
|
||||
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
|
||||
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
|
||||
# If if the incoming config has the default value, skip
|
||||
continue
|
||||
elif Path(v).name == "stable-diffusion":
|
||||
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
|
||||
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
|
||||
else:
|
||||
# Else we do not attempt to migrate this setting
|
||||
parsed_config_dict["legacy_conf_dir"] = v
|
||||
elif k in InvokeAIAppConfig.model_fields:
|
||||
# skip unknown fields
|
||||
parsed_config_dict[k] = v
|
||||
return parsed_config_dict
|
||||
|
||||
##################
|
||||
# 4.0.0 -> 4.0.1 #
|
||||
##################
|
||||
@migrator.register(from_version="4.0.0", to_version="4.0.1")
|
||||
def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
"""Migrate v4.0.0 config dictionary to v4.0.1.
|
||||
|
||||
Args:
|
||||
config_dict: A dictionary of settings from a v4.0.0 config file.
|
||||
|
||||
Returns:
|
||||
A dictionary of settings from a v4.0.1 config file
|
||||
"""
|
||||
parsed_config_dict: dict[str, Any] = {}
|
||||
for k, v in config_dict.items():
|
||||
# autocast was removed from precision in v4.0.1
|
||||
if k == "precision" and v == "autocast":
|
||||
parsed_config_dict["precision"] = "auto"
|
||||
else:
|
||||
parsed_config_dict[k] = v
|
||||
return parsed_config_dict
|
@ -9,7 +9,7 @@ from torch import Tensor
|
||||
from invokeai.app.invocations.constants import IMAGE_MODES
|
||||
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
|
||||
from invokeai.app.services.boards.boards_common import BoardDTO
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
|
||||
from invokeai.app.services.images.images_common import ImageDTO
|
||||
from invokeai.app.services.invocation_services import InvocationServices
|
||||
|
@ -1,6 +1,6 @@
|
||||
from logging import Logger
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
|
||||
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
|
||||
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1
|
||||
|
@ -1,7 +1,7 @@
|
||||
import sqlite3
|
||||
from pathlib import Path
|
||||
|
||||
from invokeai.app.services.config.config_default import InvokeAIAppConfig
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration
|
||||
|
||||
|
||||
|
@ -9,7 +9,7 @@ from einops import repeat
|
||||
from PIL import Image
|
||||
from torchvision.transforms import Compose
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
|
||||
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize
|
||||
|
@ -5,7 +5,7 @@
|
||||
import numpy as np
|
||||
import onnxruntime as ort
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
@ -6,7 +6,7 @@ import torch
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.app.util.download_with_progress import download_with_progress_bar
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
|
||||
|
@ -9,7 +9,7 @@ import numpy as np
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
|
||||
|
||||
class PatchMatch:
|
||||
|
@ -10,7 +10,7 @@ from imwatermark import WatermarkEncoder
|
||||
from PIL import Image
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
|
||||
config = get_config()
|
||||
|
||||
|
@ -12,7 +12,7 @@ from PIL import Image
|
||||
from transformers import AutoFeatureExtractor
|
||||
|
||||
import invokeai.backend.util.logging as logger
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.backend.util.devices import TorchDevice
|
||||
from invokeai.backend.util.silence_warnings import SilenceWarnings
|
||||
|
||||
|
@ -20,7 +20,7 @@ from diffusers.utils.import_utils import is_xformers_available
|
||||
from pydantic import Field
|
||||
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
|
||||
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
|
||||
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData
|
||||
|
@ -6,7 +6,7 @@ from typing import Any, Callable, Optional, Union
|
||||
import torch
|
||||
from typing_extensions import TypeAlias
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
|
||||
IPAdapterData,
|
||||
Range,
|
||||
|
@ -3,7 +3,7 @@ from typing import Dict, Literal, Optional, Union
|
||||
import torch
|
||||
from deprecated import deprecated
|
||||
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import get_config
|
||||
|
||||
# legacy APIs
|
||||
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]
|
||||
|
@ -180,8 +180,7 @@ import urllib.parse
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from invokeai.app.services.config import InvokeAIAppConfig
|
||||
from invokeai.app.services.config.config_default import get_config
|
||||
from invokeai.app.services.config import InvokeAIAppConfig, get_config
|
||||
|
||||
try:
|
||||
import syslog
|
||||
|
@ -12,10 +12,9 @@ from invokeai.app.services.config.config_default import (
|
||||
CONFIG_SCHEMA_VERSION,
|
||||
DefaultInvokeAIAppConfig,
|
||||
InvokeAIAppConfig,
|
||||
get_config,
|
||||
load_and_migrate_config,
|
||||
)
|
||||
from invokeai.app.services.config.config_migrate import ConfigMigrator
|
||||
from invokeai.app.services.config.config_migrate import ConfigMigrator, get_config, load_and_migrate_config
|
||||
from invokeai.app.services.config.migrations import Migrations, MigrationsBase
|
||||
from invokeai.app.services.shared.graph import Graph
|
||||
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
|
||||
|
||||
@ -72,6 +71,68 @@ i like turtles
|
||||
"""
|
||||
|
||||
|
||||
class GoodMigrations(MigrationsBase):
|
||||
@classmethod
|
||||
def load(cls, migrator: ConfigMigrator) -> None:
|
||||
"""Define migrations to perform."""
|
||||
|
||||
@migrator.register(from_version="3.0.0", to_version="10.0.0")
|
||||
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
@migrator.register(from_version="10.0.0", to_version="10.0.1")
|
||||
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
@migrator.register(from_version="10.0.1", to_version="10.0.2")
|
||||
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
|
||||
class BadMigrations1(MigrationsBase):
|
||||
"""This one fails because there is no path from 10.0.1 to 10.0.2"""
|
||||
|
||||
@classmethod
|
||||
def load(cls, migrator: ConfigMigrator) -> None:
|
||||
"""Define migrations to perform."""
|
||||
|
||||
@migrator.register(from_version="3.0.0", to_version="10.0.0")
|
||||
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
@migrator.register(from_version="10.0.0", to_version="10.0.1")
|
||||
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
@migrator.register(from_version="10.0.2", to_version="10.0.3")
|
||||
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
|
||||
class BadMigrations2(MigrationsBase):
|
||||
"""This one fails because the path to 10.0.2 is registered twice"""
|
||||
|
||||
@classmethod
|
||||
def load(cls, migrator: ConfigMigrator) -> None:
|
||||
"""Define migrations to perform."""
|
||||
|
||||
@migrator.register(from_version="3.0.0", to_version="10.0.0")
|
||||
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
@migrator.register(from_version="10.0.0", to_version="10.0.1")
|
||||
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
@migrator.register(from_version="10.0.1", to_version="10.0.2")
|
||||
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
@migrator.register(from_version="10.0.0", to_version="10.0.2")
|
||||
def migration_4(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
|
||||
"""This may be overkill since the current tests don't need the root dir to exist"""
|
||||
@ -291,33 +352,26 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch
|
||||
|
||||
|
||||
def test_migration_check() -> None:
|
||||
new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"})
|
||||
# Test the default set of migrations
|
||||
migrator = ConfigMigrator(Migrations)
|
||||
new_config = migrator.run_migrations({"schema_version": "4.0.0"})
|
||||
assert new_config is not None
|
||||
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION
|
||||
|
||||
# Does this execute at compile time or run time?
|
||||
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION, to_version=CONFIG_SCHEMA_VERSION + ".1")
|
||||
def ok_migration(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
# Test a custom set of migrations
|
||||
migrator = ConfigMigrator(GoodMigrations)
|
||||
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
|
||||
assert new_config["schema_version"] == "10.0.2"
|
||||
|
||||
new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"})
|
||||
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".1"
|
||||
|
||||
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".2", to_version=CONFIG_SCHEMA_VERSION + ".3")
|
||||
def bad_migration(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
# Because there is no version for "*.1" => "*.2", this should fail.
|
||||
# Test a migration that should fail validation
|
||||
migrator = ConfigMigrator(BadMigrations1)
|
||||
with pytest.raises(ValueError):
|
||||
ConfigMigrator.migrate({"schema_version": "4.0.0"})
|
||||
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
|
||||
|
||||
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".1", to_version=CONFIG_SCHEMA_VERSION + ".2")
|
||||
def good_migration(config_dict: dict[str, Any]) -> dict[str, Any]:
|
||||
return config_dict
|
||||
|
||||
# should work now, because there is a continuous path to *.3
|
||||
new_config = ConfigMigrator.migrate(new_config)
|
||||
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".3"
|
||||
# Test another bad migration
|
||||
migrator = ConfigMigrator(BadMigrations2)
|
||||
with pytest.raises(ValueError):
|
||||
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
|
||||
|
||||
|
||||
@contextmanager
|
||||
|
Loading…
Reference in New Issue
Block a user