make config migrator into an instance; refactor location of get_config()

This commit is contained in:
Lincoln Stein 2024-05-02 23:45:34 -04:00
parent d5aee87684
commit a48abfacf4
21 changed files with 314 additions and 219 deletions

View File

@ -26,7 +26,7 @@ import invokeai.backend.util.hotfixes # noqa: F401 (monkeypatching on import)
import invokeai.frontend.web as web_dir import invokeai.frontend.web as web_dir
from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles from invokeai.app.api.no_cache_staticfiles import NoCacheStaticFiles
from invokeai.app.invocations.model import ModelIdentifierField from invokeai.app.invocations.model import ModelIdentifierField
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config import get_config
from invokeai.app.services.session_processor.session_processor_common import ProgressImage from invokeai.app.services.session_processor.session_processor_common import ProgressImage
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice

View File

@ -3,7 +3,7 @@ import sys
from importlib.util import module_from_spec, spec_from_file_location from importlib.util import module_from_spec, spec_from_file_location
from pathlib import Path from pathlib import Path
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config import get_config
custom_nodes_path = Path(get_config().custom_nodes_path) custom_nodes_path = Path(get_config().custom_nodes_path)
custom_nodes_path.mkdir(parents=True, exist_ok=True) custom_nodes_path.mkdir(parents=True, exist_ok=True)

View File

@ -33,7 +33,7 @@ from invokeai.app.invocations.fields import (
FieldKind, FieldKind,
Input, Input,
) )
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config import get_config
from invokeai.app.services.shared.invocation_context import InvocationContext from invokeai.app.services.shared.invocation_context import InvocationContext
from invokeai.app.util.metaenum import MetaEnum from invokeai.app.util.metaenum import MetaEnum
from invokeai.app.util.misc import uuid_string from invokeai.app.util.misc import uuid_string

View File

@ -2,6 +2,7 @@
from invokeai.app.services.config.config_common import PagingArgumentParser from invokeai.app.services.config.config_common import PagingArgumentParser
from .config_default import InvokeAIAppConfig, get_config from .config_default import InvokeAIAppConfig
from .config_migrate import get_config
__all__ = ["InvokeAIAppConfig", "get_config", "PagingArgumentParser"] __all__ = ["InvokeAIAppConfig", "get_config", "PagingArgumentParser"]

View File

@ -3,11 +3,8 @@
from __future__ import annotations from __future__ import annotations
import locale
import os import os
import re import re
import shutil
from functools import lru_cache
from pathlib import Path from pathlib import Path
from typing import Any, Literal, Optional from typing import Any, Literal, Optional
@ -16,11 +13,7 @@ import yaml
from pydantic import BaseModel, Field, PrivateAttr, field_validator from pydantic import BaseModel, Field, PrivateAttr, field_validator
from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict from pydantic_settings import BaseSettings, PydanticBaseSettingsSource, SettingsConfigDict
import invokeai.configs as model_configs
from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS from invokeai.backend.model_hash.model_hash import HASHING_ALGORITHMS
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
from .config_migrate import ConfigMigrator
INIT_FILE = Path("invokeai.yaml") INIT_FILE = Path("invokeai.yaml")
DB_FILE = Path("invokeai.db") DB_FILE = Path("invokeai.db")
@ -348,162 +341,3 @@ class DefaultInvokeAIAppConfig(InvokeAIAppConfig):
file_secret_settings: PydanticBaseSettingsSource, file_secret_settings: PydanticBaseSettingsSource,
) -> tuple[PydanticBaseSettingsSource, ...]: ) -> tuple[PydanticBaseSettingsSource, ...]:
return (init_settings,) return (init_settings,)
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
Args:
config_path: Path to the config file.
Returns:
An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
"""
assert config_path.suffix == ".yaml"
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
loaded_config_dict = yaml.safe_load(file)
assert isinstance(loaded_config_dict, dict)
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
# loaded_config_dict could be the wrong shape, but we will catch all exceptions below
migrated_config_dict = ConfigMigrator.migrate(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e
# Attempt to load as a v4 config file
try:
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e
@lru_cache(maxsize=1)
def get_config() -> InvokeAIAppConfig:
"""Get the global singleton app config.
When first called, this function:
- Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file.
- Retrieves any provided CLI args from the InvokeAIArgs class. It does not _parse_ the CLI args; that is done in the main entrypoint.
- Sets the root dir, if provided via CLI args.
- Logs in to HF if there is no valid token already.
- Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers).
- Reads and merges in settings from the config file if it exists, else writes out a default config file.
On subsequent calls, the object is returned from the cache.
"""
# This object includes environment variables, as parsed by pydantic-settings
config = InvokeAIAppConfig()
args = InvokeAIArgs.args
# This flag serves as a proxy for whether the config was retrieved in the context of the full application or not.
# If it is False, we should just return a default config and not set the root, log in to HF, etc.
if not InvokeAIArgs.did_parse:
return config
# Set CLI args
if root := getattr(args, "root", None):
config._root = Path(root)
if config_file := getattr(args, "config_file", None):
config._config_file = Path(config_file)
# Create the example config file, with some extra example values provided
example_config = DefaultInvokeAIAppConfig()
example_config.remote_api_tokens = [
URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"),
URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"),
]
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
# Copy all legacy configs - We know `__path__[0]` is correct here
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
if config.config_file_path.exists():
config_from_file = load_and_migrate_config(config.config_file_path)
config_from_file.write_file(config.config_file_path)
# Clobbering here will overwrite any settings that were set via environment variables
config.update_config(config_from_file, clobber=False)
else:
# We should never write env vars to the config file
default_config = DefaultInvokeAIAppConfig()
default_config.write_file(config.config_file_path, as_example=False)
return config
####################################################
# VERSION MIGRATIONS
####################################################
@ConfigMigrator.register(from_version="3.0.0", to_version="4.0.0")
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate a v3 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v3 config file.
Returns:
A dictionary of settings from a 4.0.0 config file.
"""
parsed_config_dict: dict[str, Any] = {}
for _category_name, category_dict in config_dict["InvokeAI"].items():
for k, v in category_dict.items():
# `outdir` was renamed to `outputs_dir` in v4
if k == "outdir":
parsed_config_dict["outputs_dir"] = v
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
# If if the incoming config has the default value, skip
continue
elif Path(v).name == "stable-diffusion":
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
else:
# Else we do not attempt to migrate this setting
parsed_config_dict["legacy_conf_dir"] = v
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
return parsed_config_dict
@ConfigMigrator.register(from_version="4.0.0", to_version="4.0.1")
def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.0 config dictionary to v4.0.1.
Args:
config_dict: A dictionary of settings from a v4.0.0 config file.
Returns:
A dictionary of settings from a v4.0.1 config file
"""
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
else:
parsed_config_dict[k] = v
return parsed_config_dict

View File

@ -4,12 +4,22 @@
Utility class for migrating among versions of the InvokeAI app config schema. Utility class for migrating among versions of the InvokeAI app config schema.
""" """
import locale
import shutil
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, List, TypeAlias from functools import lru_cache
from pathlib import Path
from typing import Callable, List, TypeAlias
import yaml
from packaging.version import Version from packaging.version import Version
AppConfigDict: TypeAlias = dict[str, Any] import invokeai.configs as model_configs
from invokeai.frontend.cli.arg_parser import InvokeAIArgs
from .config_default import CONFIG_SCHEMA_VERSION, DefaultInvokeAIAppConfig, InvokeAIAppConfig, URLRegexTokenPair
from .migrations import AppConfigDict, Migrations, MigrationsBase
MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict] MigrationFunction: TypeAlias = Callable[[AppConfigDict], AppConfigDict]
@ -25,22 +35,23 @@ class MigrationEntry:
class ConfigMigrator: class ConfigMigrator:
"""This class allows migrators to register their input and output versions.""" """This class allows migrators to register their input and output versions."""
_migrations: List[MigrationEntry] = [] def __init__(self, migrations: type[MigrationsBase]) -> None:
self._migrations: List[MigrationEntry] = []
migrations.load(self)
@classmethod
def register( def register(
cls, self,
from_version: str, from_version: str,
to_version: str, to_version: str,
) -> Callable[[MigrationFunction], MigrationFunction]: ) -> Callable[[MigrationFunction], MigrationFunction]:
"""Define a decorator which registers the migration between two versions.""" """Define a decorator which registers the migration between two versions."""
def decorator(function: MigrationFunction) -> MigrationFunction: def decorator(function: MigrationFunction) -> MigrationFunction:
if any(from_version == m.from_version for m in cls._migrations): if any((from_version == m.from_version) or (to_version == m.to_version) for m in self._migrations):
raise ValueError( raise ValueError(
f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered." f"function {function.__name__} is trying to register a migration for version {str(from_version)}, but this migration has already been registered."
) )
cls._migrations.append( self._migrations.append(
MigrationEntry(from_version=Version(from_version), to_version=Version(to_version), function=function) MigrationEntry(from_version=Version(from_version), to_version=Version(to_version), function=function)
) )
return function return function
@ -57,8 +68,7 @@ class ConfigMigrator:
) )
current_version = m.to_version current_version = m.to_version
@classmethod def run_migrations(self, config_dict: AppConfigDict) -> AppConfigDict:
def migrate(cls, config_dict: AppConfigDict) -> AppConfigDict:
""" """
Use the registered migration steps to bring config up to latest version. Use the registered migration steps to bring config up to latest version.
@ -72,8 +82,8 @@ class ConfigMigrator:
""" """
# Sort migrations by version number and raise a ValueError if # Sort migrations by version number and raise a ValueError if
# any version range overlaps are detected. # any version range overlaps are detected.
sorted_migrations = sorted(cls._migrations, key=lambda x: x.from_version) sorted_migrations = sorted(self._migrations, key=lambda x: x.from_version)
cls._check_for_discontinuities(sorted_migrations) self._check_for_discontinuities(sorted_migrations)
if "InvokeAI" in config_dict: if "InvokeAI" in config_dict:
version = Version("3.0.0") version = Version("3.0.0")
@ -87,3 +97,92 @@ class ConfigMigrator:
config_dict["schema_version"] = str(version) config_dict["schema_version"] = str(version)
return config_dict return config_dict
@lru_cache(maxsize=1)
def get_config() -> InvokeAIAppConfig:
"""Get the global singleton app config.
When first called, this function:
- Creates a config object. `pydantic-settings` handles merging of settings from environment variables, but not the init file.
- Retrieves any provided CLI args from the InvokeAIArgs class. It does not _parse_ the CLI args; that is done in the main entrypoint.
- Sets the root dir, if provided via CLI args.
- Logs in to HF if there is no valid token already.
- Copies all legacy configs to the legacy conf dir (needed for conversion from ckpt to diffusers).
- Reads and merges in settings from the config file if it exists, else writes out a default config file.
On subsequent calls, the object is returned from the cache.
"""
# This object includes environment variables, as parsed by pydantic-settings
config = InvokeAIAppConfig()
args = InvokeAIArgs.args
# This flag serves as a proxy for whether the config was retrieved in the context of the full application or not.
# If it is False, we should just return a default config and not set the root, log in to HF, etc.
if not InvokeAIArgs.did_parse:
return config
# Set CLI args
if root := getattr(args, "root", None):
config._root = Path(root)
if config_file := getattr(args, "config_file", None):
config._config_file = Path(config_file)
# Create the example config file, with some extra example values provided
example_config = DefaultInvokeAIAppConfig()
example_config.remote_api_tokens = [
URLRegexTokenPair(url_regex="cool-models.com", token="my_secret_token"),
URLRegexTokenPair(url_regex="nifty-models.com", token="some_other_token"),
]
example_config.write_file(config.config_file_path.with_suffix(".example.yaml"), as_example=True)
# Copy all legacy configs - We know `__path__[0]` is correct here
configs_src = Path(model_configs.__path__[0]) # pyright: ignore [reportUnknownMemberType, reportUnknownArgumentType, reportAttributeAccessIssue]
shutil.copytree(configs_src, config.legacy_conf_path, dirs_exist_ok=True)
if config.config_file_path.exists():
config_from_file = load_and_migrate_config(config.config_file_path)
config_from_file.write_file(config.config_file_path)
# Clobbering here will overwrite any settings that were set via environment variables
config.update_config(config_from_file, clobber=False)
else:
# We should never write env vars to the config file
default_config = DefaultInvokeAIAppConfig()
default_config.write_file(config.config_file_path, as_example=False)
return config
def load_and_migrate_config(config_path: Path) -> InvokeAIAppConfig:
"""Load and migrate a config file to the latest version.
Args:
config_path: Path to the config file.
Returns:
An instance of `InvokeAIAppConfig` with the loaded and migrated settings.
"""
assert config_path.suffix == ".yaml"
with open(config_path, "rt", encoding=locale.getpreferredencoding()) as file:
loaded_config_dict = yaml.safe_load(file)
assert isinstance(loaded_config_dict, dict)
shutil.copy(config_path, config_path.with_suffix(".yaml.bak"))
try:
migrator = ConfigMigrator(Migrations)
migrated_config_dict = migrator.run_migrations(loaded_config_dict) # pyright: ignore [reportUnknownArgumentType]
except Exception as e:
shutil.copy(config_path.with_suffix(".yaml.bak"), config_path)
raise RuntimeError(f"Failed to load and migrate config file {config_path}: {e}") from e
# Attempt to load as a v4 config file
try:
config = InvokeAIAppConfig.model_validate(migrated_config_dict)
assert (
config.schema_version == CONFIG_SCHEMA_VERSION
), f"Invalid schema version, expected {CONFIG_SCHEMA_VERSION} but got {config.schema_version}"
return config
except Exception as e:
raise RuntimeError(f"Failed to load config file {config_path}: {e}") from e

View File

@ -0,0 +1,108 @@
# Copyright 2024 Lincoln D. Stein and the InvokeAI Development Team
"""
Schema migrations to perform on an InvokeAIAppConfig object.
The Migrations class defined in this module defines a series of
schema version migration steps for the InvokeAIConfig object.
To define a new migration, add a migration function to
Migrations.load_migrations() following the existing examples.
"""
from abc import ABC, abstractmethod
from pathlib import Path
from typing import TYPE_CHECKING, Any, TypeAlias
from .config_default import InvokeAIAppConfig
if TYPE_CHECKING:
from .config_migrate import ConfigMigrator
AppConfigDict: TypeAlias = dict[str, Any]
class MigrationsBase(ABC):
"""Define the config file migration steps to apply, abstract base class."""
@classmethod
@abstractmethod
def load(cls, migrator: "ConfigMigrator") -> None:
"""Use the provided migrator to register the configuration migrations to be run."""
class Migrations(MigrationsBase):
"""Configuration migration steps to apply."""
@classmethod
def load(cls, migrator: "ConfigMigrator") -> None:
"""Define migrations to perform."""
##################
# 3.0.0 -> 4.0.0 #
##################
@migrator.register(from_version="3.0.0", to_version="4.0.0")
def migrate_1(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate a v3 config dictionary to a current config object.
Args:
config_dict: A dictionary of settings from a v3 config file.
Returns:
A dictionary of settings from a 4.0.0 config file.
"""
parsed_config_dict: dict[str, Any] = {}
for _category_name, category_dict in config_dict["InvokeAI"].items():
for k, v in category_dict.items():
# `outdir` was renamed to `outputs_dir` in v4
if k == "outdir":
parsed_config_dict["outputs_dir"] = v
# `max_cache_size` was renamed to `ram` some time in v3, but both names were used
if k == "max_cache_size" and "ram" not in category_dict:
parsed_config_dict["ram"] = v
# `max_vram_cache_size` was renamed to `vram` some time in v3, but both names were used
if k == "max_vram_cache_size" and "vram" not in category_dict:
parsed_config_dict["vram"] = v
# autocast was removed in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
if k == "conf_path":
parsed_config_dict["legacy_models_yaml_path"] = v
if k == "legacy_conf_dir":
# The old default for this was "configs/stable-diffusion" ("configs\stable-diffusion" on Windows).
if v == "configs/stable-diffusion" or v == "configs\\stable-diffusion":
# If if the incoming config has the default value, skip
continue
elif Path(v).name == "stable-diffusion":
# Else if the path ends in "stable-diffusion", we assume the parent is the new correct path.
parsed_config_dict["legacy_conf_dir"] = str(Path(v).parent)
else:
# Else we do not attempt to migrate this setting
parsed_config_dict["legacy_conf_dir"] = v
elif k in InvokeAIAppConfig.model_fields:
# skip unknown fields
parsed_config_dict[k] = v
return parsed_config_dict
##################
# 4.0.0 -> 4.0.1 #
##################
@migrator.register(from_version="4.0.0", to_version="4.0.1")
def migrate_2(config_dict: dict[str, Any]) -> dict[str, Any]:
"""Migrate v4.0.0 config dictionary to v4.0.1.
Args:
config_dict: A dictionary of settings from a v4.0.0 config file.
Returns:
A dictionary of settings from a v4.0.1 config file
"""
parsed_config_dict: dict[str, Any] = {}
for k, v in config_dict.items():
# autocast was removed from precision in v4.0.1
if k == "precision" and v == "autocast":
parsed_config_dict["precision"] = "auto"
else:
parsed_config_dict[k] = v
return parsed_config_dict

View File

@ -9,7 +9,7 @@ from torch import Tensor
from invokeai.app.invocations.constants import IMAGE_MODES from invokeai.app.invocations.constants import IMAGE_MODES
from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata from invokeai.app.invocations.fields import MetadataField, WithBoard, WithMetadata
from invokeai.app.services.boards.boards_common import BoardDTO from invokeai.app.services.boards.boards_common import BoardDTO
from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin from invokeai.app.services.image_records.image_records_common import ImageCategory, ResourceOrigin
from invokeai.app.services.images.images_common import ImageDTO from invokeai.app.services.images.images_common import ImageDTO
from invokeai.app.services.invocation_services import InvocationServices from invokeai.app.services.invocation_services import InvocationServices

View File

@ -1,6 +1,6 @@
from logging import Logger from logging import Logger
from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase from invokeai.app.services.image_files.image_files_base import ImageFileStorageBase
from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase from invokeai.app.services.shared.sqlite.sqlite_database import SqliteDatabase
from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1 from invokeai.app.services.shared.sqlite_migrator.migrations.migration_1 import build_migration_1

View File

@ -1,7 +1,7 @@
import sqlite3 import sqlite3
from pathlib import Path from pathlib import Path
from invokeai.app.services.config.config_default import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig
from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration from invokeai.app.services.shared.sqlite_migrator.sqlite_migrator_common import Migration

View File

@ -9,7 +9,7 @@ from einops import repeat
from PIL import Image from PIL import Image
from torchvision.transforms import Compose from torchvision.transforms import Compose
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2 from invokeai.backend.image_util.depth_anything.model.dpt import DPT_DINOv2
from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize from invokeai.backend.image_util.depth_anything.utilities.util import NormalizeImage, PrepareForNet, Resize

View File

@ -5,7 +5,7 @@
import numpy as np import numpy as np
import onnxruntime as ort import onnxruntime as ort
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice

View File

@ -6,7 +6,7 @@ import torch
from PIL import Image from PIL import Image
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config import get_config
from invokeai.app.util.download_with_progress import download_with_progress_bar from invokeai.app.util.download_with_progress import download_with_progress_bar
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice

View File

@ -9,7 +9,7 @@ import numpy as np
from PIL import Image from PIL import Image
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config import get_config
class PatchMatch: class PatchMatch:

View File

@ -10,7 +10,7 @@ from imwatermark import WatermarkEncoder
from PIL import Image from PIL import Image
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config import get_config
config = get_config() config = get_config()

View File

@ -12,7 +12,7 @@ from PIL import Image
from transformers import AutoFeatureExtractor from transformers import AutoFeatureExtractor
import invokeai.backend.util.logging as logger import invokeai.backend.util.logging as logger
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config import get_config
from invokeai.backend.util.devices import TorchDevice from invokeai.backend.util.devices import TorchDevice
from invokeai.backend.util.silence_warnings import SilenceWarnings from invokeai.backend.util.silence_warnings import SilenceWarnings

View File

@ -20,7 +20,7 @@ from diffusers.utils.import_utils import is_xformers_available
from pydantic import Field from pydantic import Field
from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer from transformers import CLIPFeatureExtractor, CLIPTextModel, CLIPTokenizer
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData from invokeai.backend.stable_diffusion.diffusion.conditioning_data import IPAdapterData, TextConditioningData
from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent from invokeai.backend.stable_diffusion.diffusion.shared_invokeai_diffusion import InvokeAIDiffuserComponent
from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData from invokeai.backend.stable_diffusion.diffusion.unet_attention_patcher import UNetAttentionPatcher, UNetIPAdapterData

View File

@ -6,7 +6,7 @@ from typing import Any, Callable, Optional, Union
import torch import torch
from typing_extensions import TypeAlias from typing_extensions import TypeAlias
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config import get_config
from invokeai.backend.stable_diffusion.diffusion.conditioning_data import ( from invokeai.backend.stable_diffusion.diffusion.conditioning_data import (
IPAdapterData, IPAdapterData,
Range, Range,

View File

@ -3,7 +3,7 @@ from typing import Dict, Literal, Optional, Union
import torch import torch
from deprecated import deprecated from deprecated import deprecated
from invokeai.app.services.config.config_default import get_config from invokeai.app.services.config import get_config
# legacy APIs # legacy APIs
TorchPrecisionNames = Literal["float32", "float16", "bfloat16"] TorchPrecisionNames = Literal["float32", "float16", "bfloat16"]

View File

@ -180,8 +180,7 @@ import urllib.parse
from pathlib import Path from pathlib import Path
from typing import Any, Dict, Optional from typing import Any, Dict, Optional
from invokeai.app.services.config import InvokeAIAppConfig from invokeai.app.services.config import InvokeAIAppConfig, get_config
from invokeai.app.services.config.config_default import get_config
try: try:
import syslog import syslog

View File

@ -12,10 +12,9 @@ from invokeai.app.services.config.config_default import (
CONFIG_SCHEMA_VERSION, CONFIG_SCHEMA_VERSION,
DefaultInvokeAIAppConfig, DefaultInvokeAIAppConfig,
InvokeAIAppConfig, InvokeAIAppConfig,
get_config,
load_and_migrate_config,
) )
from invokeai.app.services.config.config_migrate import ConfigMigrator from invokeai.app.services.config.config_migrate import ConfigMigrator, get_config, load_and_migrate_config
from invokeai.app.services.config.migrations import Migrations, MigrationsBase
from invokeai.app.services.shared.graph import Graph from invokeai.app.services.shared.graph import Graph
from invokeai.frontend.cli.arg_parser import InvokeAIArgs from invokeai.frontend.cli.arg_parser import InvokeAIArgs
@ -72,6 +71,68 @@ i like turtles
""" """
class GoodMigrations(MigrationsBase):
@classmethod
def load(cls, migrator: ConfigMigrator) -> None:
"""Define migrations to perform."""
@migrator.register(from_version="3.0.0", to_version="10.0.0")
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.0", to_version="10.0.1")
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.1", to_version="10.0.2")
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
class BadMigrations1(MigrationsBase):
"""This one fails because there is no path from 10.0.1 to 10.0.2"""
@classmethod
def load(cls, migrator: ConfigMigrator) -> None:
"""Define migrations to perform."""
@migrator.register(from_version="3.0.0", to_version="10.0.0")
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.0", to_version="10.0.1")
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.2", to_version="10.0.3")
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
class BadMigrations2(MigrationsBase):
"""This one fails because the path to 10.0.2 is registered twice"""
@classmethod
def load(cls, migrator: ConfigMigrator) -> None:
"""Define migrations to perform."""
@migrator.register(from_version="3.0.0", to_version="10.0.0")
def migration_1(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.0", to_version="10.0.1")
def migration_2(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.1", to_version="10.0.2")
def migration_3(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@migrator.register(from_version="10.0.0", to_version="10.0.2")
def migration_4(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
@pytest.fixture @pytest.fixture
def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None: def patch_rootdir(tmp_path: Path, monkeypatch: Any) -> None:
"""This may be overkill since the current tests don't need the root dir to exist""" """This may be overkill since the current tests don't need the root dir to exist"""
@ -291,33 +352,26 @@ def test_get_config_writing(patch_rootdir: None, monkeypatch: pytest.MonkeyPatch
def test_migration_check() -> None: def test_migration_check() -> None:
new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"}) # Test the default set of migrations
migrator = ConfigMigrator(Migrations)
new_config = migrator.run_migrations({"schema_version": "4.0.0"})
assert new_config is not None assert new_config is not None
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION
# Does this execute at compile time or run time? # Test a custom set of migrations
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION, to_version=CONFIG_SCHEMA_VERSION + ".1") migrator = ConfigMigrator(GoodMigrations)
def ok_migration(config_dict: dict[str, Any]) -> dict[str, Any]: new_config = migrator.run_migrations({"schema_version": "10.0.0"})
return config_dict assert new_config["schema_version"] == "10.0.2"
new_config = ConfigMigrator.migrate({"schema_version": "4.0.0"}) # Test a migration that should fail validation
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".1" migrator = ConfigMigrator(BadMigrations1)
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".2", to_version=CONFIG_SCHEMA_VERSION + ".3")
def bad_migration(config_dict: dict[str, Any]) -> dict[str, Any]:
return config_dict
# Because there is no version for "*.1" => "*.2", this should fail.
with pytest.raises(ValueError): with pytest.raises(ValueError):
ConfigMigrator.migrate({"schema_version": "4.0.0"}) new_config = migrator.run_migrations({"schema_version": "10.0.0"})
@ConfigMigrator.register(from_version=CONFIG_SCHEMA_VERSION + ".1", to_version=CONFIG_SCHEMA_VERSION + ".2") # Test another bad migration
def good_migration(config_dict: dict[str, Any]) -> dict[str, Any]: migrator = ConfigMigrator(BadMigrations2)
return config_dict with pytest.raises(ValueError):
new_config = migrator.run_migrations({"schema_version": "10.0.0"})
# should work now, because there is a continuous path to *.3
new_config = ConfigMigrator.migrate(new_config)
assert new_config["schema_version"] == CONFIG_SCHEMA_VERSION + ".3"
@contextmanager @contextmanager